diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -13,15 +13,22 @@
 # limitations under the License.
 
 import re
+from typing import List
 
 import torch
 
-from ..utils import is_peft_version, logging
+from ..utils import is_peft_version, logging, state_dict_all_zero
 
 
 logger = logging.get_logger(__name__)
 
 
+def swap_scale_shift(weight):
+    shift, scale = weight.chunk(2, dim=0)
+    new_weight = torch.cat([scale, shift], dim=0)
+    return new_weight
+
+
 def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
     # 1. get all state_dict_keys
     all_keys = list(state_dict.keys())
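
For context on the new module-level swap_scale_shift helper (previously a local function inside _convert_bfl_flux_control_lora_to_diffusers, removed in a later hunk): checkpoints store AdaLN modulation weights as (shift, scale) halves along dim 0, and the converter flips them to the (scale, shift) order diffusers expects. A minimal sketch of the effect, not part of the diff:

    import torch

    def swap_scale_shift(weight):
        shift, scale = weight.chunk(2, dim=0)
        return torch.cat([scale, shift], dim=0)

    w = torch.arange(6.0)                          # rows 0-2 play "shift", rows 3-5 play "scale"
    print(swap_scale_shift(w))                     # tensor([3., 4., 5., 0., 1., 2.])
    print(swap_scale_shift(swap_scale_shift(w)))   # applying it twice restores the original
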
@@ -177,9 +184,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
             # Store DoRA scale if present.
             if dora_present_in_unet:
                 dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
-                unet_state_dict[
-                    diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
-                ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = (
+                    state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                )
 
         # Handle text encoder LoRAs.
         elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
@@ -199,13 +206,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
                     "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
                 )
                 if lora_name.startswith(("lora_te_", "lora_te1_")):
-                    te_state_dict[
-                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
-                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                    te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
+                        state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                    )
                 elif lora_name.startswith("lora_te2_"):
-                    te2_state_dict[
-                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
-                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                    te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
+                        state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                    )
 
         # Store alpha if present.
         if lora_name_alpha in state_dict:
@@ -313,6 +320,7 @@ def _convert_text_encoder_lora_key(key, lora_name):
         # Be aware that this is the new diffusers convention and the rest of the code might
         # not utilize it yet.
         diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+
     return diffusers_name
 
 
@@ -331,8 +339,7 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
 
 
 # The utilities under `_convert_kohya_flux_lora_to_diffusers()`
-# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
-# All credits go to `kohya-ss`.
+# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
 def _convert_kohya_flux_lora_to_diffusers(state_dict):
     def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
         if sds_key + ".lora_down.weight" not in sds_sd:
@@ -341,7 +348,8 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
 
         # scale weight by alpha and dim
         rank = down_weight.shape[0]
-        alpha = sds_sd.pop(sds_key + ".alpha").item()  # alpha is scalar
+        default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
+        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()  # alpha is scalar
         scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
 
         # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
@@ -362,7 +370,10 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
         sd_lora_rank = down_weight.shape[0]
 
         # scale weight by alpha and dim
-        alpha = sds_sd.pop(sds_key + ".alpha")
+        default_alpha = torch.tensor(
+            sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
+        )
+        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
         scale = alpha / sd_lora_rank
 
         # calculate scale_down and scale_up
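
A note on the two alpha hunks above: kohya-style checkpoints store a per-module alpha and apply alpha / rank at inference time; the converter bakes that factor into the weights instead, and 0.33.0 now defaults alpha to rank (overall scale 1.0) when the key is absent. The factor is then split between the down and up matrices; the loop only rebalances when the scale is below 1 (alpha < rank), the common case. A worked sketch mirroring the code above:

    def split_scale(alpha: float, rank: int):
        # (up * scale_up) @ (down * scale_down) == (up @ down) * alpha / rank,
        # so scale_down * scale_up must equal alpha / rank.
        scale = alpha / rank
        scale_down, scale_up = scale, 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2
        return scale_down, scale_up

    print(split_scale(1, 16))   # overall scale 1/16, split evenly: (0.25, 0.25)
    print(split_scale(16, 16))  # the new default when alpha is missing: (1.0, 1.0)
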
@@ -516,10 +527,103 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
                 f"transformer.single_transformer_blocks.{i}.norm.linear",
             )
 
+        # TODO: alphas.
+        def assign_remaining_weights(assignments, source):
+            for lora_key in ["lora_A", "lora_B"]:
+                orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+                for target_fmt, source_fmt, transform in assignments:
+                    target_key = target_fmt.format(lora_key=lora_key)
+                    source_key = source_fmt.format(orig_lora_key=orig_lora_key)
+                    value = source.pop(source_key)
+                    if transform:
+                        value = transform(value)
+                    ait_sd[target_key] = value
+
+        if any("guidance_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    (
+                        "time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
+                        "lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                    (
+                        "time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
+                        "lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                ],
+                sds_sd,
+            )
+
+        if any("img_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    ("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
+                ],
+                sds_sd,
+            )
+
+        if any("txt_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    ("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
+                ],
+                sds_sd,
+            )
+
+        if any("time_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    (
+                        "time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
+                        "lora_unet_time_in_in_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                    (
+                        "time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
+                        "lora_unet_time_in_out_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                ],
+                sds_sd,
+            )
+
+        if any("vector_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    (
+                        "time_text_embed.text_embedder.linear_1.{lora_key}.weight",
+                        "lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                    (
+                        "time_text_embed.text_embedder.linear_2.{lora_key}.weight",
+                        "lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                ],
+                sds_sd,
+            )
+
+        if any("final_layer" in k for k in sds_sd):
+            # Notice the swap in processing for "final_layer".
+            assign_remaining_weights(
+                [
+                    (
+                        "norm_out.linear.{lora_key}.weight",
+                        "lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight",
+                        swap_scale_shift,
+                    ),
+                    ("proj_out.{lora_key}.weight", "lora_unet_final_layer_linear.{orig_lora_key}.weight", None),
+                ],
+                sds_sd,
+            )
+
         remaining_keys = list(sds_sd.keys())
         te_state_dict = {}
         if remaining_keys:
-            if not all(k.startswith("lora_te1") for k in remaining_keys):
+            if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys):
                 raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
             for key in remaining_keys:
                 if not key.endswith("lora_down.weight"):
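
The new assign_remaining_weights helper is table-driven: each entry is a (target format, source format, optional transform) triple, instantiated once for lora_A/lora_down and once for lora_B/lora_up; only the final_layer case passes a transform (swap_scale_shift). A standalone sketch of the same pattern with hypothetical keys:

    import torch

    source = {
        "lora_unet_img_in.lora_down.weight": torch.zeros(4, 64),
        "lora_unet_img_in.lora_up.weight": torch.zeros(64, 4),
    }
    target = {}
    assignments = [("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None)]

    for lora_key, orig_lora_key in [("lora_A", "lora_down"), ("lora_B", "lora_up")]:
        for target_fmt, source_fmt, transform in assignments:
            value = source.pop(source_fmt.format(orig_lora_key=orig_lora_key))
            target[target_fmt.format(lora_key=lora_key)] = transform(value) if transform else value

    print(sorted(target))  # ['x_embedder.lora_A.weight', 'x_embedder.lora_B.weight']
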
@@ -558,6 +662,223 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
         new_state_dict = {**ait_sd, **te_state_dict}
         return new_state_dict
 
+    def _convert_mixture_state_dict_to_diffusers(state_dict):
+        new_state_dict = {}
+
+        def _convert(original_key, diffusers_key, state_dict, new_state_dict):
+            down_key = f"{original_key}.lora_down.weight"
+            down_weight = state_dict.pop(down_key)
+            lora_rank = down_weight.shape[0]
+
+            up_weight_key = f"{original_key}.lora_up.weight"
+            up_weight = state_dict.pop(up_weight_key)
+
+            alpha_key = f"{original_key}.alpha"
+            alpha = state_dict.pop(alpha_key)
+
+            # scale weight by alpha and dim
+            scale = alpha / lora_rank
+            # calculate scale_down and scale_up
+            scale_down = scale
+            scale_up = 1.0
+            while scale_down * 2 < scale_up:
+                scale_down *= 2
+                scale_up /= 2
+            down_weight = down_weight * scale_down
+            up_weight = up_weight * scale_up
+
+            diffusers_down_key = f"{diffusers_key}.lora_A.weight"
+            new_state_dict[diffusers_down_key] = down_weight
+            new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
+
+        all_unique_keys = {
+            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
+            for k in state_dict
+            if not k.startswith(("lora_unet_"))
+        }
+        assert all(k.startswith(("lora_transformer_", "lora_te1_")) for k in all_unique_keys), f"{all_unique_keys=}"
+
+        has_te_keys = False
+        for k in all_unique_keys:
+            if k.startswith("lora_transformer_single_transformer_blocks_"):
+                i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
+                diffusers_key = f"single_transformer_blocks.{i}"
+            elif k.startswith("lora_transformer_transformer_blocks_"):
+                i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
+                diffusers_key = f"transformer_blocks.{i}"
+            elif k.startswith("lora_te1_"):
+                has_te_keys = True
+                continue
+            else:
+                raise NotImplementedError
+
+            if "attn_" in k:
+                if "_to_out_0" in k:
+                    diffusers_key += ".attn.to_out.0"
+                elif "_to_add_out" in k:
+                    diffusers_key += ".attn.to_add_out"
+                elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
+                    remaining = k.split("attn_")[-1]
+                    diffusers_key += f".attn.{remaining}"
+                elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
+                    remaining = k.split("attn_")[-1]
+                    diffusers_key += f".attn.{remaining}"
+
+            _convert(k, diffusers_key, state_dict, new_state_dict)
+
+        if has_te_keys:
+            layer_pattern = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
+            attn_mapping = {
+                "q_proj": ".self_attn.q_proj",
+                "k_proj": ".self_attn.k_proj",
+                "v_proj": ".self_attn.v_proj",
+                "out_proj": ".self_attn.out_proj",
+            }
+            mlp_mapping = {"fc1": ".mlp.fc1", "fc2": ".mlp.fc2"}
+            for k in all_unique_keys:
+                if not k.startswith("lora_te1_"):
+                    continue
+
+                match = layer_pattern.search(k)
+                if not match:
+                    continue
+                i = int(match.group(1))
+                diffusers_key = f"text_model.encoder.layers.{i}"
+
+                if "attn" in k:
+                    for key_fragment, suffix in attn_mapping.items():
+                        if key_fragment in k:
+                            diffusers_key += suffix
+                            break
+                elif "mlp" in k:
+                    for key_fragment, suffix in mlp_mapping.items():
+                        if key_fragment in k:
+                            diffusers_key += suffix
+                            break
+
+                _convert(k, diffusers_key, state_dict, new_state_dict)
+
+        remaining_all_unet = False
+        if state_dict:
+            remaining_all_unet = all(k.startswith("lora_unet_") for k in state_dict)
+        if remaining_all_unet:
+            keys = list(state_dict.keys())
+            for k in keys:
+                state_dict.pop(k)
+
+        if len(state_dict) > 0:
+            raise ValueError(
+                f"Expected an empty state dict at this point but it has these keys which couldn't be parsed: {list(state_dict.keys())}."
+            )
+
+        transformer_state_dict = {
+            f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
+        }
+        te_state_dict = {f"text_encoder.{k}": v for k, v in new_state_dict.items() if k.startswith("text_model.")}
+        return {**transformer_state_dict, **te_state_dict}
+
+    # This is weird.
+    # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
+    # has both `peft` and non-peft state dict.
+    has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
+    if has_peft_state_dict:
+        state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
+        return state_dict
+
+    # Another weird one.
+    has_mixture = any(
+        k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
+    )
+
+    # ComfyUI.
+    if not has_mixture:
+        state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
+        state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
+
+        has_position_embedding = any("position_embedding" in k for k in state_dict)
+        if has_position_embedding:
+            zero_status_pe = state_dict_all_zero(state_dict, "position_embedding")
+            if zero_status_pe:
+                logger.info(
+                    "The `position_embedding` LoRA params are all zeros which make them ineffective. "
+                    "So, we will purge them out of the current state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "The state_dict has position_embedding LoRA params and we currently do not support them. "
+                    "Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k}
+
+        has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
+        if has_t5xxl:
+            zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl")
+            if zero_status_t5:
+                logger.info(
+                    "The `t5xxl` LoRA params are all zeros which make them ineffective. "
+                    "So, we will purge them out of the current state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out. "
+                    "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
+
+        has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
+        if has_diffb:
+            zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
+            if zero_status_diff_b:
+                logger.info(
+                    "The `diff_b` LoRA params are all zeros which make them ineffective. "
+                    "So, we will purge them out of the current state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "`diff_b` keys found in the state dict which are currently unsupported. "
+                    "So, we will filter out those keys. Open an issue if this is a problem - "
+                    "https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if ".diff_b" not in k}
+
+        has_norm_diff = any(".norm" in k and ".diff" in k for k in state_dict)
+        if has_norm_diff:
+            zero_status_diff = state_dict_all_zero(state_dict, ".diff")
+            if zero_status_diff:
+                logger.info(
+                    "The `diff` LoRA params are all zeros which make them ineffective. "
+                    "So, we will purge them out of the current state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "Normalization diff keys found in the state dict which are currently unsupported. "
+                    "So, we will filter out those keys. Open an issue if this is a problem - "
+                    "https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if ".norm" not in k and ".diff" not in k}
+
+        limit_substrings = ["lora_down", "lora_up"]
+        if any("alpha" in k for k in state_dict):
+            limit_substrings.append("alpha")
+
+        state_dict = {
+            _custom_replace(k, limit_substrings): v
+            for k, v in state_dict.items()
+            if k.startswith(("lora_unet_", "lora_te_"))
+        }
+
+        if any("text_projection" in k for k in state_dict):
+            logger.info(
+                "`text_projection` keys found in the `state_dict` which are unexpected. "
+                "So, we will filter out those keys. Open an issue if this is a problem - "
+                "https://github.com/huggingface/diffusers/issues/new."
+            )
+            state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}
+
+    if has_mixture:
+        return _convert_mixture_state_dict_to_diffusers(state_dict)
+
     return _convert_sd_scripts_to_ai_toolkit(state_dict)
 
 
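
The zero-checks above lean on state_dict_all_zero, newly imported from ..utils (its definition ships in diffusers/utils/state_dict_utils.py, entry 378 in the file list, and is not shown in this diff). From the call sites, a plausible reading of its behavior is the following sketch; this is an assumption, not the shipped implementation:

    import torch

    def state_dict_all_zero(state_dict, filter_str=None):
        # Assumed semantics inferred from the call sites: restrict to keys containing
        # `filter_str`, then report whether every matching tensor is entirely zero.
        if filter_str is not None:
            state_dict = {k: v for k, v in state_dict.items() if filter_str in k}
        return all(torch.count_nonzero(v) == 0 for v in state_dict.values())
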
@@ -669,6 +990,26 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
     return new_state_dict
 
 
+def _custom_replace(key: str, substrings: List[str]) -> str:
+    # Replaces the "."s with "_"s up to the first of the `substrings`.
+    # Example:
+    #   lora_unet.foo.bar.lora_A.weight -> lora_unet_foo_bar.lora_A.weight
+    pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"
+
+    match = re.search(pattern, key)
+    if match:
+        start_sub = match.start()
+        if start_sub > 0 and key[start_sub - 1] == ".":
+            boundary = start_sub - 1
+        else:
+            boundary = start_sub
+        left = key[:boundary].replace(".", "_")
+        right = key[boundary:]
+        return left + right
+    else:
+        return key.replace(".", "_")
+
+
 def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     converted_state_dict = {}
     original_state_dict_keys = list(original_state_dict.keys())
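
A quick usage note on _custom_replace: it rewrites the dots before the first lora_down/lora_up/alpha fragment into underscores (the kohya key style) while leaving the boundary dot and everything after it intact. An illustrative call with a made-up key:

    print(_custom_replace("lora_unet_double_blocks.0.img_attn.proj.lora_down.weight",
                          ["lora_down", "lora_up"]))
    # -> "lora_unet_double_blocks_0_img_attn_proj.lora_down.weight"
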
@@ -677,28 +1018,23 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     inner_dim = 3072
     mlp_ratio = 4.0
 
-    def swap_scale_shift(weight):
-        shift, scale = weight.chunk(2, dim=0)
-        new_weight = torch.cat([scale, shift], dim=0)
-        return new_weight
-
     for lora_key in ["lora_A", "lora_B"]:
         ## time_text_embed.timestep_embedder <- time_in
-        converted_state_dict[
-            f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
-        ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
+        converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"] = (
+            original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
+        )
         if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
-            converted_state_dict[
-                f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
-            ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
+            converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"] = (
+                original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
+            )
 
-        converted_state_dict[
-            f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
-        ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
+        converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"] = (
+            original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
+        )
         if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
-            converted_state_dict[
-                f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
-            ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
+            converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"] = (
+                original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
+            )
 
         ## time_text_embed.text_embedder <- vector_in
         converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
@@ -720,21 +1056,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
         # guidance
         has_guidance = any("guidance" in k for k in original_state_dict)
         if has_guidance:
-            converted_state_dict[
-                f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
-            ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
+            converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"] = (
+                original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
+            )
             if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
-                converted_state_dict[
-                    f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
-                ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
+                converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"] = (
+                    original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
+                )
 
-            converted_state_dict[
-                f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
-            ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
+            converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"] = (
+                original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
+            )
             if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
-                converted_state_dict[
-                    f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
-                ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
+                converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"] = (
+                    original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
+                )
 
         # context_embedder
         converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
973
1309
  converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
974
1310
 
975
1311
  return converted_state_dict
1312
+
1313
+
1314
+ def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
1315
+ converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
1316
+
1317
+ def remap_norm_scale_shift_(key, state_dict):
1318
+ weight = state_dict.pop(key)
1319
+ shift, scale = weight.chunk(2, dim=0)
1320
+ new_weight = torch.cat([scale, shift], dim=0)
1321
+ state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
1322
+
1323
+ def remap_txt_in_(key, state_dict):
1324
+ def rename_key(key):
1325
+ new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
1326
+ new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
1327
+ new_key = new_key.replace("txt_in", "context_embedder")
1328
+ new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
1329
+ new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
1330
+ new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
1331
+ new_key = new_key.replace("mlp", "ff")
1332
+ return new_key
1333
+
1334
+ if "self_attn_qkv" in key:
1335
+ weight = state_dict.pop(key)
1336
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
1337
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
1338
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
1339
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
1340
+ else:
1341
+ state_dict[rename_key(key)] = state_dict.pop(key)
1342
+
1343
+ def remap_img_attn_qkv_(key, state_dict):
1344
+ weight = state_dict.pop(key)
1345
+ if "lora_A" in key:
1346
+ state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
1347
+ state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
1348
+ state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
1349
+ else:
1350
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
1351
+ state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
1352
+ state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
1353
+ state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
1354
+
1355
+ def remap_txt_attn_qkv_(key, state_dict):
1356
+ weight = state_dict.pop(key)
1357
+ if "lora_A" in key:
1358
+ state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
1359
+ state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
1360
+ state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
1361
+ else:
1362
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
1363
+ state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
1364
+ state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
1365
+ state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
1366
+
1367
+ def remap_single_transformer_blocks_(key, state_dict):
1368
+ hidden_size = 3072
1369
+
1370
+ if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
1371
+ linear1_weight = state_dict.pop(key)
1372
+ if "lora_A" in key:
1373
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
1374
+ ".linear1.lora_A.weight"
1375
+ )
1376
+ state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
1377
+ state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
1378
+ state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
1379
+ state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
1380
+ else:
1381
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
1382
+ q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
1383
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
1384
+ ".linear1.lora_B.weight"
1385
+ )
1386
+ state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
1387
+ state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
1388
+ state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
1389
+ state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp
1390
+
1391
+ elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
1392
+ linear1_bias = state_dict.pop(key)
1393
+ if "lora_A" in key:
1394
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
1395
+ ".linear1.lora_A.bias"
1396
+ )
1397
+ state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
1398
+ state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
1399
+ state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
1400
+ state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
1401
+ else:
1402
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
1403
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
1404
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
1405
+ ".linear1.lora_B.bias"
1406
+ )
1407
+ state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
1408
+ state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
1409
+ state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
1410
+ state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias
1411
+
1412
+ else:
1413
+ new_key = key.replace("single_blocks", "single_transformer_blocks")
1414
+ new_key = new_key.replace("linear2", "proj_out")
1415
+ new_key = new_key.replace("q_norm", "attn.norm_q")
1416
+ new_key = new_key.replace("k_norm", "attn.norm_k")
1417
+ state_dict[new_key] = state_dict.pop(key)
1418
+
1419
+ TRANSFORMER_KEYS_RENAME_DICT = {
1420
+ "img_in": "x_embedder",
1421
+ "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
1422
+ "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
1423
+ "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
1424
+ "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
1425
+ "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
1426
+ "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
1427
+ "double_blocks": "transformer_blocks",
1428
+ "img_attn_q_norm": "attn.norm_q",
1429
+ "img_attn_k_norm": "attn.norm_k",
1430
+ "img_attn_proj": "attn.to_out.0",
1431
+ "txt_attn_q_norm": "attn.norm_added_q",
1432
+ "txt_attn_k_norm": "attn.norm_added_k",
1433
+ "txt_attn_proj": "attn.to_add_out",
1434
+ "img_mod.linear": "norm1.linear",
1435
+ "img_norm1": "norm1.norm",
1436
+ "img_norm2": "norm2",
1437
+ "img_mlp": "ff",
1438
+ "txt_mod.linear": "norm1_context.linear",
1439
+ "txt_norm1": "norm1.norm",
1440
+ "txt_norm2": "norm2_context",
1441
+ "txt_mlp": "ff_context",
1442
+ "self_attn_proj": "attn.to_out.0",
1443
+ "modulation.linear": "norm.linear",
1444
+ "pre_norm": "norm.norm",
1445
+ "final_layer.norm_final": "norm_out.norm",
1446
+ "final_layer.linear": "proj_out",
1447
+ "fc1": "net.0.proj",
1448
+ "fc2": "net.2",
1449
+ "input_embedder": "proj_in",
1450
+ }
1451
+
1452
+ TRANSFORMER_SPECIAL_KEYS_REMAP = {
1453
+ "txt_in": remap_txt_in_,
1454
+ "img_attn_qkv": remap_img_attn_qkv_,
1455
+ "txt_attn_qkv": remap_txt_attn_qkv_,
1456
+ "single_blocks": remap_single_transformer_blocks_,
1457
+ "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
1458
+ }
1459
+
1460
+     # Some folks attempt to make their state dicts compatible with diffusers by adding a "transformer." prefix to
+     # all keys and using their own custom code. To make sure both "original" and "attempted diffusers" LoRAs work
+     # as expected, we normalize both to the same initial format by stripping off that prefix.
+     for key in list(converted_state_dict.keys()):
+         if key.startswith("transformer."):
+             converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
+         if key.startswith("diffusion_model."):
+             converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
+
+     # Rename and remap the state dict keys
+     for key in list(converted_state_dict.keys()):
+         new_key = key[:]
+         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+             new_key = new_key.replace(replace_key, rename_key)
+         converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+     for key in list(converted_state_dict.keys()):
+         for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+             if special_key not in key:
+                 continue
+             handler_fn_inplace(key, converted_state_dict)
+
+     # Add back the "transformer." prefix
+     for key in list(converted_state_dict.keys()):
+         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+     return converted_state_dict
+
+
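Note that `TRANSFORMER_KEYS_RENAME_DICT` is applied by plain substring replacement in dictionary order, and `TRANSFORMER_SPECIAL_KEYS_REMAP` dispatches on substring containment, so a single key can be rewritten by several entries and entry order matters. A small standalone sketch of the same pattern (hypothetical table and key, not the tables above):

# Minimal sketch of the substring-based rename pattern used above
# (hypothetical rename table, not the diffusers one).
RENAME = {"double_blocks": "transformer_blocks", "img_mlp": "ff"}

def rename(key: str) -> str:
    # Each table entry is applied as a plain substring replacement, in
    # insertion order, so one key may be rewritten by several entries.
    for old, new in RENAME.items():
        key = key.replace(old, new)
    return key

print(rename("double_blocks.0.img_mlp.fc1.lora_A.weight"))
# -> "transformer_blocks.0.ff.fc1.lora_A.weight"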
+ def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
+     # Remove "diffusion_model." prefix from keys.
+     state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
+     converted_state_dict = {}
+
+     def get_num_layers(keys, pattern):
+         layers = set()
+         for key in keys:
+             match = re.search(pattern, key)
+             if match:
+                 layers.add(int(match.group(1)))
+         return len(layers)
+
+     def process_block(prefix, index, convert_norm):
+         # Process attention qkv: pop lora_A and lora_B weights.
+         lora_down = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_A.weight")
+         lora_up = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_B.weight")
+         for attn_key in ["to_q", "to_k", "to_v"]:
+             converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_A.weight"] = lora_down
+         for attn_key, weight in zip(["to_q", "to_k", "to_v"], torch.split(lora_up, [2304, 768, 768], dim=0)):
+             converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_B.weight"] = weight
+
+         # Process attention out weights.
+         converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_A.weight"] = state_dict.pop(
+             f"{prefix}.{index}.attention.out.lora_A.weight"
+         )
+         converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_B.weight"] = state_dict.pop(
+             f"{prefix}.{index}.attention.out.lora_B.weight"
+         )
+
+         # Process feed-forward weights for layers 1, 2, and 3.
+         for layer in range(1, 4):
+             converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_A.weight"] = state_dict.pop(
+                 f"{prefix}.{index}.feed_forward.w{layer}.lora_A.weight"
+             )
+             converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_B.weight"] = state_dict.pop(
+                 f"{prefix}.{index}.feed_forward.w{layer}.lora_B.weight"
+             )
+
+         if convert_norm:
+             converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_A.weight"] = state_dict.pop(
+                 f"{prefix}.{index}.adaLN_modulation.1.lora_A.weight"
+             )
+             converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_B.weight"] = state_dict.pop(
+                 f"{prefix}.{index}.adaLN_modulation.1.lora_B.weight"
+             )
+
+     noise_refiner_pattern = r"noise_refiner\.(\d+)\."
+     num_noise_refiner_layers = get_num_layers(state_dict.keys(), noise_refiner_pattern)
+     for i in range(num_noise_refiner_layers):
+         process_block("noise_refiner", i, convert_norm=True)
+
+     context_refiner_pattern = r"context_refiner\.(\d+)\."
+     num_context_refiner_layers = get_num_layers(state_dict.keys(), context_refiner_pattern)
+     for i in range(num_context_refiner_layers):
+         process_block("context_refiner", i, convert_norm=False)
+
+     core_transformer_pattern = r"layers\.(\d+)\."
+     num_core_transformer_layers = get_num_layers(state_dict.keys(), core_transformer_pattern)
+     for i in range(num_core_transformer_layers):
+         process_block("layers", i, convert_norm=True)
+
+     if len(state_dict) > 0:
+         raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+     for key in list(converted_state_dict.keys()):
+         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+     return converted_state_dict
+
+
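The Lumina2 converter follows the same shared-`lora_A`/split-`lora_B` pattern for the fused `attention.qkv`, here at the fixed row offsets `[2304, 768, 768]`, and it decides how many blocks to process by counting the distinct layer indices a regex captures rather than the number of matching keys. A standalone check of that counting idea, with hypothetical keys:

import re

def get_num_layers(keys, pattern):
    # Count distinct captured indices, not key occurrences: every layer
    # contributes several keys (qkv, out, feed-forward, ...).
    layers = set()
    for key in keys:
        match = re.search(pattern, key)
        if match:
            layers.add(int(match.group(1)))
    return len(layers)

keys = [
    "noise_refiner.0.attention.qkv.lora_A.weight",  # hypothetical keys
    "noise_refiner.0.attention.qkv.lora_B.weight",
    "noise_refiner.1.attention.out.lora_A.weight",
]
assert get_num_layers(keys, r"noise_refiner\.(\d+)\.") == 2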
+ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
+     converted_state_dict = {}
+     original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
+
+     num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
+     is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
+
+     for i in range(num_blocks):
+         # Self-attention
+         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+             converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
+                 f"blocks.{i}.self_attn.{o}.lora_A.weight"
+             )
+             converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
+                 f"blocks.{i}.self_attn.{o}.lora_B.weight"
+             )
+
+         # Cross-attention
+         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+             converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
+                 f"blocks.{i}.cross_attn.{o}.lora_A.weight"
+             )
+             converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
+                 f"blocks.{i}.cross_attn.{o}.lora_B.weight"
+             )
+
+         if is_i2v_lora:
+             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                 converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
+                     f"blocks.{i}.cross_attn.{o}.lora_A.weight"
+                 )
+                 converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
+                     f"blocks.{i}.cross_attn.{o}.lora_B.weight"
+                 )
+
+         # FFN
+         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
+             converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
+                 f"blocks.{i}.{o}.lora_A.weight"
+             )
+             converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
+                 f"blocks.{i}.{o}.lora_B.weight"
+             )
+
+     if len(original_state_dict) > 0:
+         raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
+
+     for key in list(converted_state_dict.keys()):
+         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+     return converted_state_dict
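
For reference, a hypothetical end-to-end use of the Wan converter: a fabricated single-block, rank-4, t2v-style LoRA (no `k_img`/`v_img` keys, so `is_i2v_lora` is False). The shapes are illustrative only, and the sketch assumes the function above is in scope:

import torch

rank, dim = 4, 64  # hypothetical sizes
original = {}
for name in ["self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o",
             "cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o",
             "ffn.0", "ffn.2"]:
    # The converter strips the "diffusion_model." prefix, so keys must carry it.
    original[f"diffusion_model.blocks.0.{name}.lora_A.weight"] = torch.randn(rank, dim)
    original[f"diffusion_model.blocks.0.{name}.lora_B.weight"] = torch.randn(dim, rank)

converted = _convert_non_diffusers_wan_lora_to_diffusers(original)
# Every key now uses diffusers naming under a "transformer." prefix, e.g.
# "transformer.blocks.0.attn1.to_q.lora_A.weight" and
# "transformer.blocks.0.ffn.net.0.proj.lora_A.weight".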