diffusers 0.27.1__py3-none-any.whl → 0.32.2__py3-none-any.whl

Files changed (445)
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +41 -40
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.1.dist-info/RECORD +0 -399
  443. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
@@ -14,7 +14,9 @@
 
 import re
 
-from ..utils import logging
+import torch
+
+from ..utils import is_peft_version, logging
 
 
 logger = logging.get_logger(__name__)
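
For orientation, the new `torch` and `is_peft_version` imports feed the DoRA handling introduced below. Users normally reach this module indirectly through the LoRA loader API; a minimal sketch, with a hypothetical LoRA repo id:

import torch

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
# Loading a kohya/sd-scripts-style LoRA routes through the conversion
# utilities in this file before the weights are handed to peft.
pipe.load_lora_weights("some-user/some-sdxl-lora")  # hypothetical repo id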
@@ -123,153 +125,100 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
     return new_state_dict
 
 
-def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
+def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
+    """
+    Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.
+
+    Args:
+        state_dict (`dict`): The state dict to convert.
+        unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
+        text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
+            "text_encoder".
+
+    Returns:
+        `tuple`: A tuple containing the converted state dict and a dictionary of alphas.
+    """
     unet_state_dict = {}
     te_state_dict = {}
     te2_state_dict = {}
     network_alphas = {}
 
-    # every down weight has a corresponding up weight and potentially an alpha weight
-    lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
-    for key in lora_keys:
+    # Check for DoRA-enabled LoRAs.
+    dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
+    dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
+    dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
+    if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
+        if is_peft_version("<", "0.9.0"):
+            raise ValueError(
+                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+            )
+
+    # Iterate over all LoRA weights.
+    all_lora_keys = list(state_dict.keys())
+    for key in all_lora_keys:
+        if not key.endswith("lora_down.weight"):
+            continue
+
+        # Extract LoRA name.
         lora_name = key.split(".")[0]
+
+        # Find corresponding up weight and alpha.
         lora_name_up = lora_name + ".lora_up.weight"
         lora_name_alpha = lora_name + ".alpha"
 
+        # Handle U-Net LoRAs.
         if lora_name.startswith("lora_unet_"):
-            diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+            diffusers_name = _convert_unet_lora_key(key)
 
-            if "input.blocks" in diffusers_name:
-                diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
-            else:
-                diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+            # Store down and up weights.
+            unet_state_dict[diffusers_name] = state_dict.pop(key)
+            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
 
-            if "middle.block" in diffusers_name:
-                diffusers_name = diffusers_name.replace("middle.block", "mid_block")
-            else:
-                diffusers_name = diffusers_name.replace("mid.block", "mid_block")
-            if "output.blocks" in diffusers_name:
-                diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
-            else:
-                diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
-
-            diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
-            diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
-            diffusers_name = diffusers_name.replace("proj.in", "proj_in")
-            diffusers_name = diffusers_name.replace("proj.out", "proj_out")
-            diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
-
-            # SDXL specificity.
-            if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
-                pattern = r"\.\d+(?=\D*$)"
-                diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
-            if ".in." in diffusers_name:
-                diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
-            if ".out." in diffusers_name:
-                diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
-            if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
-                diffusers_name = diffusers_name.replace("op", "conv")
-            if "skip" in diffusers_name:
-                diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
-
-            # LyCORIS specificity.
-            if "time.emb.proj" in diffusers_name:
-                diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
-            if "conv.shortcut" in diffusers_name:
-                diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
-
-            # General coverage.
-            if "transformer_blocks" in diffusers_name:
-                if "attn1" in diffusers_name or "attn2" in diffusers_name:
-                    diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
-                    diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
-                    unet_state_dict[diffusers_name] = state_dict.pop(key)
-                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-                elif "ff" in diffusers_name:
-                    unet_state_dict[diffusers_name] = state_dict.pop(key)
-                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
-                unet_state_dict[diffusers_name] = state_dict.pop(key)
-                unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            else:
-                unet_state_dict[diffusers_name] = state_dict.pop(key)
-                unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-
-        elif lora_name.startswith("lora_te_"):
-            diffusers_name = key.replace("lora_te_", "").replace("_", ".")
-            diffusers_name = diffusers_name.replace("text.model", "text_model")
-            diffusers_name = diffusers_name.replace("self.attn", "self_attn")
-            diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
-            if "self_attn" in diffusers_name:
-                te_state_dict[diffusers_name] = state_dict.pop(key)
-                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif "mlp" in diffusers_name:
-                # Be aware that this is the new diffusers convention and the rest of the code might
-                # not utilize it yet.
-                diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
-                te_state_dict[diffusers_name] = state_dict.pop(key)
-                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+            # Store DoRA scale if present.
+            if dora_present_in_unet:
+                dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
+                unet_state_dict[
+                    diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
+                ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
 
-        # (sayakpaul): Duplicate code. Needs to be cleaned.
-        elif lora_name.startswith("lora_te1_"):
-            diffusers_name = key.replace("lora_te1_", "").replace("_", ".")
-            diffusers_name = diffusers_name.replace("text.model", "text_model")
-            diffusers_name = diffusers_name.replace("self.attn", "self_attn")
-            diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
-            if "self_attn" in diffusers_name:
-                te_state_dict[diffusers_name] = state_dict.pop(key)
-                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif "mlp" in diffusers_name:
-                # Be aware that this is the new diffusers convention and the rest of the code might
-                # not utilize it yet.
-                diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+        # Handle text encoder LoRAs.
+        elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+            diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
+
+            # Store down and up weights for te or te2.
+            if lora_name.startswith(("lora_te_", "lora_te1_")):
                 te_state_dict[diffusers_name] = state_dict.pop(key)
                 te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-
-        # (sayakpaul): Duplicate code. Needs to be cleaned.
-        elif lora_name.startswith("lora_te2_"):
-            diffusers_name = key.replace("lora_te2_", "").replace("_", ".")
-            diffusers_name = diffusers_name.replace("text.model", "text_model")
-            diffusers_name = diffusers_name.replace("self.attn", "self_attn")
-            diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
-            if "self_attn" in diffusers_name:
-                te2_state_dict[diffusers_name] = state_dict.pop(key)
-                te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif "mlp" in diffusers_name:
-                # Be aware that this is the new diffusers convention and the rest of the code might
-                # not utilize it yet.
-                diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+            else:
                 te2_state_dict[diffusers_name] = state_dict.pop(key)
                 te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
 
-        # Rename the alphas so that they can be mapped appropriately.
+            # Store DoRA scale if present.
+            if dora_present_in_te or dora_present_in_te2:
+                dora_scale_key_to_replace_te = (
+                    "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
+                )
+                if lora_name.startswith(("lora_te_", "lora_te1_")):
+                    te_state_dict[
+                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                elif lora_name.startswith("lora_te2_"):
+                    te2_state_dict[
+                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+
+        # Store alpha if present.
         if lora_name_alpha in state_dict:
            alpha = state_dict.pop(lora_name_alpha).item()
-            if lora_name_alpha.startswith("lora_unet_"):
-                prefix = "unet."
-            elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
-                prefix = "text_encoder."
-            else:
-                prefix = "text_encoder_2."
-            new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
-            network_alphas.update({new_name: alpha})
+            network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))
 
+    # Check if any keys remain.
     if len(state_dict) > 0:
-        raise ValueError(f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}")
+        raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
+
+    logger.info("Non-diffusers checkpoint detected.")
 
-    logger.info("Kohya-style checkpoint detected.")
+    # Construct final state dict.
     unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
     te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
     te2_state_dict = (
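
To make the renamed entry point concrete, a minimal sketch of calling the (private) converter directly on a tiny synthetic state dict in the kohya naming scheme; the module path, shapes, and alpha are invented for illustration:

import torch

prefix = "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q"
state_dict = {
    f"{prefix}.lora_down.weight": torch.zeros(4, 320),  # rank-4 down projection
    f"{prefix}.lora_up.weight": torch.zeros(320, 4),    # rank-4 up projection
    f"{prefix}.alpha": torch.tensor(4.0),
}

new_state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
print(list(new_state_dict))  # keys are re-prefixed with "unet." and use Diffusers naming
print(network_alphas)        # one ".alpha" entry carrying the scalar 4.0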
@@ -282,3 +231,920 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
282
231
 
283
232
  new_state_dict = {**unet_state_dict, **te_state_dict}
284
233
  return new_state_dict, network_alphas
234
+
235
+
236
+ def _convert_unet_lora_key(key):
237
+ """
238
+ Converts a U-Net LoRA key to a Diffusers compatible key.
239
+ """
240
+ diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
241
+
242
+ # Replace common U-Net naming patterns.
243
+ diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
244
+ diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
245
+ diffusers_name = diffusers_name.replace("middle.block", "mid_block")
246
+ diffusers_name = diffusers_name.replace("mid.block", "mid_block")
247
+ diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
248
+ diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
249
+ diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
250
+ diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
251
+ diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
252
+ diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
253
+ diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
254
+ diffusers_name = diffusers_name.replace("proj.in", "proj_in")
255
+ diffusers_name = diffusers_name.replace("proj.out", "proj_out")
256
+ diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
257
+
258
+ # SDXL specific conversions.
259
+ if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
260
+ pattern = r"\.\d+(?=\D*$)"
261
+ diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
262
+ if ".in." in diffusers_name:
263
+ diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
264
+ if ".out." in diffusers_name:
265
+ diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
266
+ if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
267
+ diffusers_name = diffusers_name.replace("op", "conv")
268
+ if "skip" in diffusers_name:
269
+ diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
270
+
271
+ # LyCORIS specific conversions.
272
+ if "time.emb.proj" in diffusers_name:
273
+ diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
274
+ if "conv.shortcut" in diffusers_name:
275
+ diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
276
+
277
+ # General conversions.
278
+ if "transformer_blocks" in diffusers_name:
279
+ if "attn1" in diffusers_name or "attn2" in diffusers_name:
280
+ diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
281
+ diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
282
+ elif "ff" in diffusers_name:
283
+ pass
284
+ elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
285
+ pass
286
+ else:
287
+ pass
288
+
289
+ return diffusers_name
290
+
291
+
292
+ def _convert_text_encoder_lora_key(key, lora_name):
293
+ """
294
+ Converts a text encoder LoRA key to a Diffusers compatible key.
295
+ """
296
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
297
+ key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
298
+ else:
299
+ key_to_replace = "lora_te2_"
300
+
301
+ diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
302
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
303
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
304
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
305
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
306
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
307
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
308
+ diffusers_name = diffusers_name.replace("text.projection", "text_projection")
309
+
310
+ if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
311
+ pass
312
+ elif "mlp" in diffusers_name:
313
+ # Be aware that this is the new diffusers convention and the rest of the code might
314
+ # not utilize it yet.
315
+ diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
316
+ return diffusers_name
317
+
318
+
319
+ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
320
+ """
321
+ Gets the correct alpha name for the Diffusers model.
322
+ """
323
+ if lora_name_alpha.startswith("lora_unet_"):
324
+ prefix = "unet."
325
+ elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
326
+ prefix = "text_encoder."
327
+ else:
328
+ prefix = "text_encoder_2."
329
+ new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
330
+ return {new_name: alpha}
331
+
332
+
333
+ # The utilities under `_convert_kohya_flux_lora_to_diffusers()`
334
+ # are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
335
+ # All credits go to `kohya-ss`.
336
+ def _convert_kohya_flux_lora_to_diffusers(state_dict):
337
+ def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
338
+ if sds_key + ".lora_down.weight" not in sds_sd:
339
+ return
340
+ down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
341
+
342
+ # scale weight by alpha and dim
343
+ rank = down_weight.shape[0]
344
+ alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
345
+ scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
346
+
347
+ # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
348
+ scale_down = scale
349
+ scale_up = 1.0
350
+ while scale_down * 2 < scale_up:
351
+ scale_down *= 2
352
+ scale_up /= 2
353
+
354
+ ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
355
+ ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
356
+
357
+ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
358
+ if sds_key + ".lora_down.weight" not in sds_sd:
359
+ return
360
+ down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
361
+ up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
362
+ sd_lora_rank = down_weight.shape[0]
363
+
364
+ # scale weight by alpha and dim
365
+ alpha = sds_sd.pop(sds_key + ".alpha")
366
+ scale = alpha / sd_lora_rank
367
+
368
+ # calculate scale_down and scale_up
369
+ scale_down = scale
370
+ scale_up = 1.0
371
+ while scale_down * 2 < scale_up:
372
+ scale_down *= 2
373
+ scale_up /= 2
374
+
375
+ down_weight = down_weight * scale_down
376
+ up_weight = up_weight * scale_up
377
+
378
+ # calculate dims if not provided
379
+ num_splits = len(ait_keys)
380
+ if dims is None:
381
+ dims = [up_weight.shape[0] // num_splits] * num_splits
382
+ else:
383
+ assert sum(dims) == up_weight.shape[0]
384
+
385
+ # check upweight is sparse or not
386
+ is_sparse = False
387
+ if sd_lora_rank % num_splits == 0:
388
+ ait_rank = sd_lora_rank // num_splits
389
+ is_sparse = True
390
+ i = 0
391
+ for j in range(len(dims)):
392
+ for k in range(len(dims)):
393
+ if j == k:
394
+ continue
395
+ is_sparse = is_sparse and torch.all(
396
+ up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
397
+ )
398
+ i += dims[j]
399
+ if is_sparse:
400
+ logger.info(f"weight is sparse: {sds_key}")
401
+
402
+ # make ai-toolkit weight
403
+ ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
404
+ ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
405
+ if not is_sparse:
406
+ # down_weight is copied to each split
407
+ ait_sd.update({k: down_weight for k in ait_down_keys})
408
+
409
+ # up_weight is split to each split
410
+ ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
411
+ else:
412
+ # down_weight is chunked to each split
413
+ ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416
414
+
415
+ # up_weight is sparse: only non-zero values are copied to each split
416
+ i = 0
417
+ for j in range(len(dims)):
418
+ ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
419
+ i += dims[j]
420
+
421
+ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
422
+ ait_sd = {}
423
+ for i in range(19):
424
+ _convert_to_ai_toolkit(
425
+ sds_sd,
426
+ ait_sd,
427
+ f"lora_unet_double_blocks_{i}_img_attn_proj",
428
+ f"transformer.transformer_blocks.{i}.attn.to_out.0",
429
+ )
430
+ _convert_to_ai_toolkit_cat(
431
+ sds_sd,
432
+ ait_sd,
433
+ f"lora_unet_double_blocks_{i}_img_attn_qkv",
434
+ [
435
+ f"transformer.transformer_blocks.{i}.attn.to_q",
436
+ f"transformer.transformer_blocks.{i}.attn.to_k",
437
+ f"transformer.transformer_blocks.{i}.attn.to_v",
438
+ ],
439
+ )
440
+ _convert_to_ai_toolkit(
441
+ sds_sd,
442
+ ait_sd,
443
+ f"lora_unet_double_blocks_{i}_img_mlp_0",
444
+ f"transformer.transformer_blocks.{i}.ff.net.0.proj",
445
+ )
446
+ _convert_to_ai_toolkit(
447
+ sds_sd,
448
+ ait_sd,
449
+ f"lora_unet_double_blocks_{i}_img_mlp_2",
450
+ f"transformer.transformer_blocks.{i}.ff.net.2",
451
+ )
452
+ _convert_to_ai_toolkit(
453
+ sds_sd,
454
+ ait_sd,
455
+ f"lora_unet_double_blocks_{i}_img_mod_lin",
456
+ f"transformer.transformer_blocks.{i}.norm1.linear",
457
+ )
458
+ _convert_to_ai_toolkit(
459
+ sds_sd,
460
+ ait_sd,
461
+ f"lora_unet_double_blocks_{i}_txt_attn_proj",
462
+ f"transformer.transformer_blocks.{i}.attn.to_add_out",
463
+ )
464
+ _convert_to_ai_toolkit_cat(
465
+ sds_sd,
466
+ ait_sd,
467
+ f"lora_unet_double_blocks_{i}_txt_attn_qkv",
468
+ [
469
+ f"transformer.transformer_blocks.{i}.attn.add_q_proj",
470
+ f"transformer.transformer_blocks.{i}.attn.add_k_proj",
471
+ f"transformer.transformer_blocks.{i}.attn.add_v_proj",
472
+ ],
473
+ )
474
+ _convert_to_ai_toolkit(
475
+ sds_sd,
476
+ ait_sd,
477
+ f"lora_unet_double_blocks_{i}_txt_mlp_0",
478
+ f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
479
+ )
480
+ _convert_to_ai_toolkit(
481
+ sds_sd,
482
+ ait_sd,
483
+ f"lora_unet_double_blocks_{i}_txt_mlp_2",
484
+ f"transformer.transformer_blocks.{i}.ff_context.net.2",
485
+ )
486
+ _convert_to_ai_toolkit(
487
+ sds_sd,
488
+ ait_sd,
489
+ f"lora_unet_double_blocks_{i}_txt_mod_lin",
490
+ f"transformer.transformer_blocks.{i}.norm1_context.linear",
491
+ )
492
+
493
+ for i in range(38):
494
+ _convert_to_ai_toolkit_cat(
495
+ sds_sd,
496
+ ait_sd,
497
+ f"lora_unet_single_blocks_{i}_linear1",
498
+ [
499
+ f"transformer.single_transformer_blocks.{i}.attn.to_q",
500
+ f"transformer.single_transformer_blocks.{i}.attn.to_k",
501
+ f"transformer.single_transformer_blocks.{i}.attn.to_v",
502
+ f"transformer.single_transformer_blocks.{i}.proj_mlp",
503
+ ],
504
+ dims=[3072, 3072, 3072, 12288],
505
+ )
506
+ _convert_to_ai_toolkit(
507
+ sds_sd,
508
+ ait_sd,
509
+ f"lora_unet_single_blocks_{i}_linear2",
510
+ f"transformer.single_transformer_blocks.{i}.proj_out",
511
+ )
512
+ _convert_to_ai_toolkit(
513
+ sds_sd,
514
+ ait_sd,
515
+ f"lora_unet_single_blocks_{i}_modulation_lin",
516
+ f"transformer.single_transformer_blocks.{i}.norm.linear",
517
+ )
518
+
519
+ remaining_keys = list(sds_sd.keys())
520
+ te_state_dict = {}
521
+ if remaining_keys:
522
+ if not all(k.startswith("lora_te1") for k in remaining_keys):
523
+ raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
524
+ for key in remaining_keys:
525
+ if not key.endswith("lora_down.weight"):
526
+ continue
527
+
528
+ lora_name = key.split(".")[0]
529
+ lora_name_up = f"{lora_name}.lora_up.weight"
530
+ lora_name_alpha = f"{lora_name}.alpha"
531
+ diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
532
+
533
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
534
+ down_weight = sds_sd.pop(key)
535
+ sd_lora_rank = down_weight.shape[0]
536
+ te_state_dict[diffusers_name] = down_weight
537
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)
538
+
539
+ if lora_name_alpha in sds_sd:
540
+ alpha = sds_sd.pop(lora_name_alpha).item()
541
+ scale = alpha / sd_lora_rank
542
+
543
+ scale_down = scale
544
+ scale_up = 1.0
545
+ while scale_down * 2 < scale_up:
546
+ scale_down *= 2
547
+ scale_up /= 2
548
+
549
+ te_state_dict[diffusers_name] *= scale_down
550
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up
551
+
552
+ if len(sds_sd) > 0:
553
+ logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
554
+
555
+ if te_state_dict:
556
+ te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}
557
+
558
+ new_state_dict = {**ait_sd, **te_state_dict}
559
+ return new_state_dict
560
+
561
+ return _convert_sd_scripts_to_ai_toolkit(state_dict)
562
+
563
+
564
+ # Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6
565
+ # Some utilities were reused from
566
+ # https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
567
+ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
568
+ new_state_dict = {}
569
+ orig_keys = list(old_state_dict.keys())
570
+
571
+ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
572
+ down_weight = sds_sd.pop(sds_key)
573
+ up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))
574
+
575
+ # calculate dims if not provided
576
+ num_splits = len(ait_keys)
577
+ if dims is None:
578
+ dims = [up_weight.shape[0] // num_splits] * num_splits
579
+ else:
580
+ assert sum(dims) == up_weight.shape[0]
581
+
582
+ # make ai-toolkit weight
583
+ ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
584
+ ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
585
+
586
+ # down_weight is copied to each split
587
+ ait_sd.update({k: down_weight for k in ait_down_keys})
588
+
589
+ # up_weight is split to each split
590
+ ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
591
+
+     for old_key in orig_keys:
+         # Handle double_blocks
+         if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")):
+             block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1)
+             new_key = f"transformer.transformer_blocks.{block_num}"
+
+             if "processor.proj_lora1" in old_key:
+                 new_key += ".attn.to_out.0"
+             elif "processor.proj_lora2" in old_key:
+                 new_key += ".attn.to_add_out"
+             # Handle text latents.
+             elif "processor.qkv_lora2" in old_key and "up" not in old_key:
+                 handle_qkv(
+                     old_state_dict,
+                     new_state_dict,
+                     old_key,
+                     [
+                         f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
+                         f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
+                         f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
+                     ],
+                 )
+                 # continue
+             # Handle image latents.
+             elif "processor.qkv_lora1" in old_key and "up" not in old_key:
+                 handle_qkv(
+                     old_state_dict,
+                     new_state_dict,
+                     old_key,
+                     [
+                         f"transformer.transformer_blocks.{block_num}.attn.to_q",
+                         f"transformer.transformer_blocks.{block_num}.attn.to_k",
+                         f"transformer.transformer_blocks.{block_num}.attn.to_v",
+                     ],
+                 )
+                 # continue
+
+             if "down" in old_key:
+                 new_key += ".lora_A.weight"
+             elif "up" in old_key:
+                 new_key += ".lora_B.weight"
+
+         # Handle single_blocks
+         elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")):
+             block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
+             new_key = f"transformer.single_transformer_blocks.{block_num}"
+
+             if "proj_lora" in old_key:
+                 new_key += ".proj_out"
+             elif "qkv_lora" in old_key and "up" not in old_key:
+                 handle_qkv(
+                     old_state_dict,
+                     new_state_dict,
+                     old_key,
+                     [
+                         f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
+                         f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
+                         f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
+                     ],
+                 )
+
+             if "down" in old_key:
+                 new_key += ".lora_A.weight"
+             elif "up" in old_key:
+                 new_key += ".lora_B.weight"
+
+         else:
+             # Handle other potential key patterns here
+             new_key = old_key
+
+         # Since we already handle qkv above.
+         if "qkv" not in old_key:
+             new_state_dict[new_key] = old_state_dict.pop(old_key)
+
+     if len(old_state_dict) > 0:
+         raise ValueError(f"`old_state_dict` should be empty at this point but has: {list(old_state_dict.keys())}.")
+
+     return new_state_dict
+
+
+ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
+     converted_state_dict = {}
+     original_state_dict_keys = list(original_state_dict.keys())
+     num_layers = 19
+     num_single_layers = 38
+     inner_dim = 3072
+     mlp_ratio = 4.0
+
+     def swap_scale_shift(weight):
+         shift, scale = weight.chunk(2, dim=0)
+         new_weight = torch.cat([scale, shift], dim=0)
+         return new_weight
+
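+     # BFL checkpoints store the final adaLN modulation as [shift, scale], whereas the
+     # diffusers norm_out.linear expects [scale, shift]; swap_scale_shift reorders the
+     # two halves along dim 0 before the key is renamed below.
+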
+     for lora_key in ["lora_A", "lora_B"]:
+         ## time_text_embed.timestep_embedder <- time_in
+         converted_state_dict[
+             f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
+         ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
+         if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+             converted_state_dict[
+                 f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
+             ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
+
+         converted_state_dict[
+             f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
+         ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
+         if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+             converted_state_dict[
+                 f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
+             ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
+
+         ## time_text_embed.text_embedder <- vector_in
+         converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
+             f"vector_in.in_layer.{lora_key}.weight"
+         )
+         if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+             converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.bias"] = original_state_dict.pop(
+                 f"vector_in.in_layer.{lora_key}.bias"
+             )
+
+         converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = original_state_dict.pop(
+             f"vector_in.out_layer.{lora_key}.weight"
+         )
+         if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+             converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.bias"] = original_state_dict.pop(
+                 f"vector_in.out_layer.{lora_key}.bias"
+             )
+
+         # guidance
+         has_guidance = any("guidance" in k for k in original_state_dict)
+         if has_guidance:
+             converted_state_dict[
+                 f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
+             ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
+             if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[
+                     f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
+                 ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
+
+             converted_state_dict[
+                 f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
+             ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
+             if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[
+                     f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
+                 ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
+
+         # context_embedder
+         converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
+             f"txt_in.{lora_key}.weight"
+         )
+         if f"txt_in.{lora_key}.bias" in original_state_dict_keys:
+             converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop(
+                 f"txt_in.{lora_key}.bias"
+             )
+
+         # x_embedder
+         converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop(f"img_in.{lora_key}.weight")
+         if f"img_in.{lora_key}.bias" in original_state_dict_keys:
+             converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop(f"img_in.{lora_key}.bias")
+
+     # double transformer blocks
+     for i in range(num_layers):
+         block_prefix = f"transformer_blocks.{i}."
+
+         for lora_key in ["lora_A", "lora_B"]:
+             # norms
+             converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+                 f"double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+             )
+             if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
+                     f"double_blocks.{i}.img_mod.lin.{lora_key}.bias"
+                 )
+
+             converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+                 f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+             )
+             if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.bias"] = original_state_dict.pop(
+                     f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias"
+                 )
+
+             # Q, K, V
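+             # For the fused qkv LoRA pair: lora_A (rank, in_features) is the shared down
+             # projection, so it is copied as-is to to_q / to_k / to_v, while lora_B
+             # (3 * inner_dim, rank) is what fans out to the three projections, so only it
+             # is chunked along dim 0. torch.cat([t]) on a single tensor is effectively a
+             # contiguous copy of it.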
+             if lora_key == "lora_A":
+                 sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight")
+                 converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                 converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                 converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+                 context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight")
+                 converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+                     [context_lora_weight]
+                 )
+                 converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+                     [context_lora_weight]
+                 )
+                 converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+                     [context_lora_weight]
+                 )
+             else:
+                 sample_q, sample_k, sample_v = torch.chunk(
+                     original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0
+                 )
+                 converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+                 converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+                 converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+                 context_q, context_k, context_v = torch.chunk(
+                     original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0
+                 )
+                 converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+                 converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+                 converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
+             if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                 sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+                     original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0
+                 )
+                 converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
+                 converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
+                 converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
+
+             if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                 context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+                     original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0
+                 )
+                 converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
+                 converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
+                 converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
+
+             # ff img_mlp
+             converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                 f"double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+             )
+             if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                     f"double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+                 )
+
+             converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                 f"double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+             )
+             if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                     f"double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+                 )
+
+             converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                 f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+             )
+             if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                     f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+                 )
+
+             converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                 f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+             )
+             if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                     f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+                 )
+
+             # output projections.
+             converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+                 f"double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+             )
+             if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+                     f"double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+                 )
+             converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+                 f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+             )
+             if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+                     f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+                 )
+
+         # qk_norm
+         converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+             f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+         )
+         converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+             f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+         )
+         converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
+             f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+         )
+         converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
+             f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+         )
+
+     # single transformer blocks
+     for i in range(num_single_layers):
+         block_prefix = f"single_transformer_blocks.{i}."
+
+         for lora_key in ["lora_A", "lora_B"]:
+             # norm.linear <- single_blocks.0.modulation.lin
+             converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+                 f"single_blocks.{i}.modulation.lin.{lora_key}.weight"
+             )
+             if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+                     f"single_blocks.{i}.modulation.lin.{lora_key}.bias"
+                 )
+
+             # Q, K, V, mlp
+             mlp_hidden_dim = int(inner_dim * mlp_ratio)
+             split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
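+             # linear1 fuses the q, k, v projections with the MLP input projection, so the
+             # lora_B rows split as (3072, 3072, 3072, 12288) given inner_dim = 3072 and
+             # mlp_ratio = 4.0; lora_A is shared across all four targets, as above.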
+
+             if lora_key == "lora_A":
+                 lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight")
+                 converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+                 converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+                 converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+                 converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
+                 if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                     lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
+                     converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
+                     converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
+                     converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
+                     converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
+             else:
+                 q, k, v, mlp = torch.split(
+                     original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0
+                 )
+                 converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+                 converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+                 converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+                 converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+                 if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                     q_bias, k_bias, v_bias, mlp_bias = torch.split(
+                         original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0
+                     )
+                     converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+                     converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+                     converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+                     converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+             # output projections.
+             converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+                 f"single_blocks.{i}.linear2.{lora_key}.weight"
+             )
+             if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                     f"single_blocks.{i}.linear2.{lora_key}.bias"
+                 )
+
+         # qk norm
+         converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+             f"single_blocks.{i}.norm.query_norm.scale"
+         )
+         converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+             f"single_blocks.{i}.norm.key_norm.scale"
+         )
+
+     for lora_key in ["lora_A", "lora_B"]:
+         converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+             f"final_layer.linear.{lora_key}.weight"
+         )
+         if f"final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+             converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                 f"final_layer.linear.{lora_key}.bias"
+             )
+
+         converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift(
+             original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight")
+         )
+         if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict_keys:
+             converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift(
+                 original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias")
+             )
+
+     if len(original_state_dict) > 0:
+         raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+     for key in list(converted_state_dict.keys()):
+         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+     return converted_state_dict
+
+
+ def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
+     converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
+
+     def remap_norm_scale_shift_(key, state_dict):
+         weight = state_dict.pop(key)
+         shift, scale = weight.chunk(2, dim=0)
+         new_weight = torch.cat([scale, shift], dim=0)
+         state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
+     def remap_txt_in_(key, state_dict):
+         def rename_key(key):
+             new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+             new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+             new_key = new_key.replace("txt_in", "context_embedder")
+             new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+             new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+             new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+             new_key = new_key.replace("mlp", "ff")
+             return new_key
+
+         if "self_attn_qkv" in key:
+             weight = state_dict.pop(key)
+             to_q, to_k, to_v = weight.chunk(3, dim=0)
+             state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+             state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+             state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+         else:
+             state_dict[rename_key(key)] = state_dict.pop(key)
+
+     def remap_img_attn_qkv_(key, state_dict):
+         weight = state_dict.pop(key)
+         if "lora_A" in key:
+             state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
+             state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
+             state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
+         else:
+             to_q, to_k, to_v = weight.chunk(3, dim=0)
+             state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+             state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+             state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+     def remap_txt_attn_qkv_(key, state_dict):
+         weight = state_dict.pop(key)
+         if "lora_A" in key:
+             state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
+             state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
+             state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
+         else:
+             to_q, to_k, to_v = weight.chunk(3, dim=0)
+             state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+             state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+             state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
+     def remap_single_transformer_blocks_(key, state_dict):
+         hidden_size = 3072
+
+         if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
+             linear1_weight = state_dict.pop(key)
+             if "lora_A" in key:
+                 new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                     ".linear1.lora_A.weight"
+                 )
+                 state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
+                 state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
+                 state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
+                 state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
+             else:
+                 split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+                 q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+                 new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                     ".linear1.lora_B.weight"
+                 )
+                 state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
+                 state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
+                 state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
+                 state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp
+
+         elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
+             linear1_bias = state_dict.pop(key)
+             if "lora_A" in key:
+                 new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                     ".linear1.lora_A.bias"
+                 )
+                 state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
+                 state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
+                 state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
+                 state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
+             else:
+                 split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+                 q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+                 new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                     ".linear1.lora_B.bias"
+                 )
+                 state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
+                 state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
+                 state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
+                 state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias
+
+         else:
+             new_key = key.replace("single_blocks", "single_transformer_blocks")
+             new_key = new_key.replace("linear2", "proj_out")
+             new_key = new_key.replace("q_norm", "attn.norm_q")
+             new_key = new_key.replace("k_norm", "attn.norm_k")
+             state_dict[new_key] = state_dict.pop(key)
+
+     TRANSFORMER_KEYS_RENAME_DICT = {
+         "img_in": "x_embedder",
+         "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+         "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+         "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+         "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+         "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+         "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+         "double_blocks": "transformer_blocks",
+         "img_attn_q_norm": "attn.norm_q",
+         "img_attn_k_norm": "attn.norm_k",
+         "img_attn_proj": "attn.to_out.0",
+         "txt_attn_q_norm": "attn.norm_added_q",
+         "txt_attn_k_norm": "attn.norm_added_k",
+         "txt_attn_proj": "attn.to_add_out",
+         "img_mod.linear": "norm1.linear",
+         "img_norm1": "norm1.norm",
+         "img_norm2": "norm2",
+         "img_mlp": "ff",
+         "txt_mod.linear": "norm1_context.linear",
+         "txt_norm1": "norm1.norm",
+         "txt_norm2": "norm2_context",
+         "txt_mlp": "ff_context",
+         "self_attn_proj": "attn.to_out.0",
+         "modulation.linear": "norm.linear",
+         "pre_norm": "norm.norm",
+         "final_layer.norm_final": "norm_out.norm",
+         "final_layer.linear": "proj_out",
+         "fc1": "net.0.proj",
+         "fc2": "net.2",
+         "input_embedder": "proj_in",
+     }
+
+     TRANSFORMER_SPECIAL_KEYS_REMAP = {
+         "txt_in": remap_txt_in_,
+         "img_attn_qkv": remap_img_attn_qkv_,
+         "txt_attn_qkv": remap_txt_attn_qkv_,
+         "single_blocks": remap_single_transformer_blocks_,
+         "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+     }
+
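+     # Keys are converted in two passes below: plain substring renames via
+     # TRANSFORMER_KEYS_RENAME_DICT, then the special handlers above, each of which
+     # mutates the state dict in place (popping the matched key and writing back the
+     # renamed, and where needed chunked, tensors).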
+     # Some folks attempt to make their state dict compatible with diffusers by adding a "transformer." prefix to
+     # all keys and using their own custom code. To make sure both "original" and "attempted diffusers" LoRAs work
+     # as expected, we make sure that both follow the same initial format by stripping off the "transformer." prefix.
+     for key in list(converted_state_dict.keys()):
+         if key.startswith("transformer."):
+             converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
+         if key.startswith("diffusion_model."):
+             converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
+
+     # Rename and remap the state dict keys
+     for key in list(converted_state_dict.keys()):
+         new_key = key[:]
+         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+             new_key = new_key.replace(replace_key, rename_key)
+         converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+     for key in list(converted_state_dict.keys()):
+         for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+             if special_key not in key:
+                 continue
+             handler_fn_inplace(key, converted_state_dict)
+
+     # Add back the "transformer." prefix
+     for key in list(converted_state_dict.keys()):
+         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+     return converted_state_dict
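+
+ # A minimal sketch of how one of these converters might be exercised directly
+ # (hypothetical usage; inside diffusers they are invoked by the LoRA loader mixins
+ # when a checkpoint in the corresponding original layout is detected):
+ #
+ #     from safetensors.torch import load_file
+ #
+ #     state_dict = load_file("pytorch_lora_weights.safetensors")  # original-layout LoRA
+ #     state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
+ #     # keys now look like "transformer.transformer_blocks.0.attn.to_q.lora_A.weight"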