diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389) hide show
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +595 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
  384. diffusers-0.33.1.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
@@ -13,15 +13,22 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import re
16
+ from typing import List
16
17
 
17
18
  import torch
18
19
 
19
- from ..utils import is_peft_version, logging
20
+ from ..utils import is_peft_version, logging, state_dict_all_zero
20
21
 
21
22
 
22
23
  logger = logging.get_logger(__name__)
23
24
 
24
25
 
26
+ def swap_scale_shift(weight):
27
+ shift, scale = weight.chunk(2, dim=0)
28
+ new_weight = torch.cat([scale, shift], dim=0)
29
+ return new_weight
30
+
31
+
25
32
  def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
26
33
  # 1. get all state_dict_keys
27
34
  all_keys = list(state_dict.keys())
@@ -177,9 +184,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
177
184
  # Store DoRA scale if present.
178
185
  if dora_present_in_unet:
179
186
  dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
180
- unet_state_dict[
181
- diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
182
- ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
187
+ unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = (
188
+ state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
189
+ )
183
190
 
184
191
  # Handle text encoder LoRAs.
185
192
  elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
@@ -199,13 +206,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
199
206
  "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
200
207
  )
201
208
  if lora_name.startswith(("lora_te_", "lora_te1_")):
202
- te_state_dict[
203
- diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
204
- ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
209
+ te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
210
+ state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
211
+ )
205
212
  elif lora_name.startswith("lora_te2_"):
206
- te2_state_dict[
207
- diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
208
- ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
213
+ te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
214
+ state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
215
+ )
209
216
 
210
217
  # Store alpha if present.
211
218
  if lora_name_alpha in state_dict:
@@ -313,6 +320,7 @@ def _convert_text_encoder_lora_key(key, lora_name):
313
320
  # Be aware that this is the new diffusers convention and the rest of the code might
314
321
  # not utilize it yet.
315
322
  diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
323
+
316
324
  return diffusers_name
317
325
 
318
326
 
@@ -331,8 +339,7 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
331
339
 
332
340
 
333
341
  # The utilities under `_convert_kohya_flux_lora_to_diffusers()`
334
- # are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
335
- # All credits go to `kohya-ss`.
342
+ # are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
336
343
  def _convert_kohya_flux_lora_to_diffusers(state_dict):
337
344
  def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
338
345
  if sds_key + ".lora_down.weight" not in sds_sd:
@@ -341,7 +348,8 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
341
348
 
342
349
  # scale weight by alpha and dim
343
350
  rank = down_weight.shape[0]
344
- alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
351
+ default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
352
+ alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item() # alpha is scalar
345
353
  scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
346
354
 
347
355
  # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
@@ -362,7 +370,10 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
362
370
  sd_lora_rank = down_weight.shape[0]
363
371
 
364
372
  # scale weight by alpha and dim
365
- alpha = sds_sd.pop(sds_key + ".alpha")
373
+ default_alpha = torch.tensor(
374
+ sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
375
+ )
376
+ alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
366
377
  scale = alpha / sd_lora_rank
367
378
 
368
379
  # calculate scale_down and scale_up
@@ -516,10 +527,103 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
516
527
  f"transformer.single_transformer_blocks.{i}.norm.linear",
517
528
  )
518
529
 
530
+ # TODO: alphas.
531
+ def assign_remaining_weights(assignments, source):
532
+ for lora_key in ["lora_A", "lora_B"]:
533
+ orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
534
+ for target_fmt, source_fmt, transform in assignments:
535
+ target_key = target_fmt.format(lora_key=lora_key)
536
+ source_key = source_fmt.format(orig_lora_key=orig_lora_key)
537
+ value = source.pop(source_key)
538
+ if transform:
539
+ value = transform(value)
540
+ ait_sd[target_key] = value
541
+
542
+ if any("guidance_in" in k for k in sds_sd):
543
+ assign_remaining_weights(
544
+ [
545
+ (
546
+ "time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
547
+ "lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
548
+ None,
549
+ ),
550
+ (
551
+ "time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
552
+ "lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
553
+ None,
554
+ ),
555
+ ],
556
+ sds_sd,
557
+ )
558
+
559
+ if any("img_in" in k for k in sds_sd):
560
+ assign_remaining_weights(
561
+ [
562
+ ("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
563
+ ],
564
+ sds_sd,
565
+ )
566
+
567
+ if any("txt_in" in k for k in sds_sd):
568
+ assign_remaining_weights(
569
+ [
570
+ ("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
571
+ ],
572
+ sds_sd,
573
+ )
574
+
575
+ if any("time_in" in k for k in sds_sd):
576
+ assign_remaining_weights(
577
+ [
578
+ (
579
+ "time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
580
+ "lora_unet_time_in_in_layer.{orig_lora_key}.weight",
581
+ None,
582
+ ),
583
+ (
584
+ "time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
585
+ "lora_unet_time_in_out_layer.{orig_lora_key}.weight",
586
+ None,
587
+ ),
588
+ ],
589
+ sds_sd,
590
+ )
591
+
592
+ if any("vector_in" in k for k in sds_sd):
593
+ assign_remaining_weights(
594
+ [
595
+ (
596
+ "time_text_embed.text_embedder.linear_1.{lora_key}.weight",
597
+ "lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
598
+ None,
599
+ ),
600
+ (
601
+ "time_text_embed.text_embedder.linear_2.{lora_key}.weight",
602
+ "lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
603
+ None,
604
+ ),
605
+ ],
606
+ sds_sd,
607
+ )
608
+
609
+ if any("final_layer" in k for k in sds_sd):
610
+ # Notice the swap in processing for "final_layer".
611
+ assign_remaining_weights(
612
+ [
613
+ (
614
+ "norm_out.linear.{lora_key}.weight",
615
+ "lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight",
616
+ swap_scale_shift,
617
+ ),
618
+ ("proj_out.{lora_key}.weight", "lora_unet_final_layer_linear.{orig_lora_key}.weight", None),
619
+ ],
620
+ sds_sd,
621
+ )
622
+
519
623
  remaining_keys = list(sds_sd.keys())
520
624
  te_state_dict = {}
521
625
  if remaining_keys:
522
- if not all(k.startswith("lora_te1") for k in remaining_keys):
626
+ if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys):
523
627
  raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
524
628
  for key in remaining_keys:
525
629
  if not key.endswith("lora_down.weight"):
@@ -558,6 +662,223 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
558
662
  new_state_dict = {**ait_sd, **te_state_dict}
559
663
  return new_state_dict
560
664
 
665
+ def _convert_mixture_state_dict_to_diffusers(state_dict):
666
+ new_state_dict = {}
667
+
668
+ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
669
+ down_key = f"{original_key}.lora_down.weight"
670
+ down_weight = state_dict.pop(down_key)
671
+ lora_rank = down_weight.shape[0]
672
+
673
+ up_weight_key = f"{original_key}.lora_up.weight"
674
+ up_weight = state_dict.pop(up_weight_key)
675
+
676
+ alpha_key = f"{original_key}.alpha"
677
+ alpha = state_dict.pop(alpha_key)
678
+
679
+ # scale weight by alpha and dim
680
+ scale = alpha / lora_rank
681
+ # calculate scale_down and scale_up
682
+ scale_down = scale
683
+ scale_up = 1.0
684
+ while scale_down * 2 < scale_up:
685
+ scale_down *= 2
686
+ scale_up /= 2
687
+ down_weight = down_weight * scale_down
688
+ up_weight = up_weight * scale_up
689
+
690
+ diffusers_down_key = f"{diffusers_key}.lora_A.weight"
691
+ new_state_dict[diffusers_down_key] = down_weight
692
+ new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
693
+
694
+ all_unique_keys = {
695
+ k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
696
+ for k in state_dict
697
+ if not k.startswith(("lora_unet_"))
698
+ }
699
+ assert all(k.startswith(("lora_transformer_", "lora_te1_")) for k in all_unique_keys), f"{all_unique_keys=}"
700
+
701
+ has_te_keys = False
702
+ for k in all_unique_keys:
703
+ if k.startswith("lora_transformer_single_transformer_blocks_"):
704
+ i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
705
+ diffusers_key = f"single_transformer_blocks.{i}"
706
+ elif k.startswith("lora_transformer_transformer_blocks_"):
707
+ i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
708
+ diffusers_key = f"transformer_blocks.{i}"
709
+ elif k.startswith("lora_te1_"):
710
+ has_te_keys = True
711
+ continue
712
+ else:
713
+ raise NotImplementedError
714
+
715
+ if "attn_" in k:
716
+ if "_to_out_0" in k:
717
+ diffusers_key += ".attn.to_out.0"
718
+ elif "_to_add_out" in k:
719
+ diffusers_key += ".attn.to_add_out"
720
+ elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
721
+ remaining = k.split("attn_")[-1]
722
+ diffusers_key += f".attn.{remaining}"
723
+ elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
724
+ remaining = k.split("attn_")[-1]
725
+ diffusers_key += f".attn.{remaining}"
726
+
727
+ _convert(k, diffusers_key, state_dict, new_state_dict)
728
+
729
+ if has_te_keys:
730
+ layer_pattern = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
731
+ attn_mapping = {
732
+ "q_proj": ".self_attn.q_proj",
733
+ "k_proj": ".self_attn.k_proj",
734
+ "v_proj": ".self_attn.v_proj",
735
+ "out_proj": ".self_attn.out_proj",
736
+ }
737
+ mlp_mapping = {"fc1": ".mlp.fc1", "fc2": ".mlp.fc2"}
738
+ for k in all_unique_keys:
739
+ if not k.startswith("lora_te1_"):
740
+ continue
741
+
742
+ match = layer_pattern.search(k)
743
+ if not match:
744
+ continue
745
+ i = int(match.group(1))
746
+ diffusers_key = f"text_model.encoder.layers.{i}"
747
+
748
+ if "attn" in k:
749
+ for key_fragment, suffix in attn_mapping.items():
750
+ if key_fragment in k:
751
+ diffusers_key += suffix
752
+ break
753
+ elif "mlp" in k:
754
+ for key_fragment, suffix in mlp_mapping.items():
755
+ if key_fragment in k:
756
+ diffusers_key += suffix
757
+ break
758
+
759
+ _convert(k, diffusers_key, state_dict, new_state_dict)
760
+
761
+ remaining_all_unet = False
762
+ if state_dict:
763
+ remaining_all_unet = all(k.startswith("lora_unet_") for k in state_dict)
764
+ if remaining_all_unet:
765
+ keys = list(state_dict.keys())
766
+ for k in keys:
767
+ state_dict.pop(k)
768
+
769
+ if len(state_dict) > 0:
770
+ raise ValueError(
771
+ f"Expected an empty state dict at this point but its has these keys which couldn't be parsed: {list(state_dict.keys())}."
772
+ )
773
+
774
+ transformer_state_dict = {
775
+ f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
776
+ }
777
+ te_state_dict = {f"text_encoder.{k}": v for k, v in new_state_dict.items() if k.startswith("text_model.")}
778
+ return {**transformer_state_dict, **te_state_dict}
779
+
780
+ # This is weird.
781
+ # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
782
+ # has both `peft` and non-peft state dict.
783
+ has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
784
+ if has_peft_state_dict:
785
+ state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
786
+ return state_dict
787
+
788
+ # Another weird one.
789
+ has_mixture = any(
790
+ k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
791
+ )
792
+
793
+ # ComfyUI.
794
+ if not has_mixture:
795
+ state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
796
+ state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
797
+
798
+ has_position_embedding = any("position_embedding" in k for k in state_dict)
799
+ if has_position_embedding:
800
+ zero_status_pe = state_dict_all_zero(state_dict, "position_embedding")
801
+ if zero_status_pe:
802
+ logger.info(
803
+ "The `position_embedding` LoRA params are all zeros which make them ineffective. "
804
+ "So, we will purge them out of the curret state dict to make loading possible."
805
+ )
806
+
807
+ else:
808
+ logger.info(
809
+ "The state_dict has position_embedding LoRA params and we currently do not support them. "
810
+ "Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new."
811
+ )
812
+ state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k}
813
+
814
+ has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
815
+ if has_t5xxl:
816
+ zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl")
817
+ if zero_status_t5:
818
+ logger.info(
819
+ "The `t5xxl` LoRA params are all zeros which make them ineffective. "
820
+ "So, we will purge them out of the curret state dict to make loading possible."
821
+ )
822
+ else:
823
+ logger.info(
824
+ "T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out."
825
+ "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
826
+ )
827
+ state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
828
+
829
+ has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
830
+ if has_diffb:
831
+ zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
832
+ if zero_status_diff_b:
833
+ logger.info(
834
+ "The `diff_b` LoRA params are all zeros which make them ineffective. "
835
+ "So, we will purge them out of the curret state dict to make loading possible."
836
+ )
837
+ else:
838
+ logger.info(
839
+ "`diff_b` keys found in the state dict which are currently unsupported. "
840
+ "So, we will filter out those keys. Open an issue if this is a problem - "
841
+ "https://github.com/huggingface/diffusers/issues/new."
842
+ )
843
+ state_dict = {k: v for k, v in state_dict.items() if ".diff_b" not in k}
844
+
845
+ has_norm_diff = any(".norm" in k and ".diff" in k for k in state_dict)
846
+ if has_norm_diff:
847
+ zero_status_diff = state_dict_all_zero(state_dict, ".diff")
848
+ if zero_status_diff:
849
+ logger.info(
850
+ "The `diff` LoRA params are all zeros which make them ineffective. "
851
+ "So, we will purge them out of the curret state dict to make loading possible."
852
+ )
853
+ else:
854
+ logger.info(
855
+ "Normalization diff keys found in the state dict which are currently unsupported. "
856
+ "So, we will filter out those keys. Open an issue if this is a problem - "
857
+ "https://github.com/huggingface/diffusers/issues/new."
858
+ )
859
+ state_dict = {k: v for k, v in state_dict.items() if ".norm" not in k and ".diff" not in k}
860
+
861
+ limit_substrings = ["lora_down", "lora_up"]
862
+ if any("alpha" in k for k in state_dict):
863
+ limit_substrings.append("alpha")
864
+
865
+ state_dict = {
866
+ _custom_replace(k, limit_substrings): v
867
+ for k, v in state_dict.items()
868
+ if k.startswith(("lora_unet_", "lora_te_"))
869
+ }
870
+
871
+ if any("text_projection" in k for k in state_dict):
872
+ logger.info(
873
+ "`text_projection` keys found in the `state_dict` which are unexpected. "
874
+ "So, we will filter out those keys. Open an issue if this is a problem - "
875
+ "https://github.com/huggingface/diffusers/issues/new."
876
+ )
877
+ state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}
878
+
879
+ if has_mixture:
880
+ return _convert_mixture_state_dict_to_diffusers(state_dict)
881
+
561
882
  return _convert_sd_scripts_to_ai_toolkit(state_dict)
562
883
 
563
884
 
@@ -669,6 +990,26 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
669
990
  return new_state_dict
670
991
 
671
992
 
993
+ def _custom_replace(key: str, substrings: List[str]) -> str:
994
+ # Replaces the "."s with "_"s upto the `substrings`.
995
+ # Example:
996
+ # lora_unet.foo.bar.lora_A.weight -> lora_unet_foo_bar.lora_A.weight
997
+ pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"
998
+
999
+ match = re.search(pattern, key)
1000
+ if match:
1001
+ start_sub = match.start()
1002
+ if start_sub > 0 and key[start_sub - 1] == ".":
1003
+ boundary = start_sub - 1
1004
+ else:
1005
+ boundary = start_sub
1006
+ left = key[:boundary].replace(".", "_")
1007
+ right = key[boundary:]
1008
+ return left + right
1009
+ else:
1010
+ return key.replace(".", "_")
1011
+
1012
+
672
1013
  def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
673
1014
  converted_state_dict = {}
674
1015
  original_state_dict_keys = list(original_state_dict.keys())
@@ -677,28 +1018,23 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
677
1018
  inner_dim = 3072
678
1019
  mlp_ratio = 4.0
679
1020
 
680
- def swap_scale_shift(weight):
681
- shift, scale = weight.chunk(2, dim=0)
682
- new_weight = torch.cat([scale, shift], dim=0)
683
- return new_weight
684
-
685
1021
  for lora_key in ["lora_A", "lora_B"]:
686
1022
  ## time_text_embed.timestep_embedder <- time_in
687
- converted_state_dict[
688
- f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
689
- ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
1023
+ converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"] = (
1024
+ original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
1025
+ )
690
1026
  if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
691
- converted_state_dict[
692
- f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
693
- ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
1027
+ converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"] = (
1028
+ original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
1029
+ )
694
1030
 
695
- converted_state_dict[
696
- f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
697
- ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
1031
+ converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"] = (
1032
+ original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
1033
+ )
698
1034
  if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
699
- converted_state_dict[
700
- f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
701
- ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
1035
+ converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"] = (
1036
+ original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
1037
+ )
702
1038
 
703
1039
  ## time_text_embed.text_embedder <- vector_in
704
1040
  converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
@@ -720,21 +1056,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
720
1056
  # guidance
721
1057
  has_guidance = any("guidance" in k for k in original_state_dict)
722
1058
  if has_guidance:
723
- converted_state_dict[
724
- f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
725
- ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
1059
+ converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"] = (
1060
+ original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
1061
+ )
726
1062
  if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
727
- converted_state_dict[
728
- f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
729
- ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
1063
+ converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"] = (
1064
+ original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
1065
+ )
730
1066
 
731
- converted_state_dict[
732
- f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
733
- ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
1067
+ converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"] = (
1068
+ original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
1069
+ )
734
1070
  if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
735
- converted_state_dict[
736
- f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
737
- ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
1071
+ converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"] = (
1072
+ original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
1073
+ )
738
1074
 
739
1075
  # context_embedder
740
1076
  converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
@@ -1148,3 +1484,127 @@ def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
1148
1484
  converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
1149
1485
 
1150
1486
  return converted_state_dict
1487
+
1488
+
1489
+ def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
1490
+ # Remove "diffusion_model." prefix from keys.
1491
+ state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
1492
+ converted_state_dict = {}
1493
+
1494
+ def get_num_layers(keys, pattern):
1495
+ layers = set()
1496
+ for key in keys:
1497
+ match = re.search(pattern, key)
1498
+ if match:
1499
+ layers.add(int(match.group(1)))
1500
+ return len(layers)
1501
+
1502
+ def process_block(prefix, index, convert_norm):
1503
+ # Process attention qkv: pop lora_A and lora_B weights.
1504
+ lora_down = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_A.weight")
1505
+ lora_up = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_B.weight")
1506
+ for attn_key in ["to_q", "to_k", "to_v"]:
1507
+ converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_A.weight"] = lora_down
1508
+ for attn_key, weight in zip(["to_q", "to_k", "to_v"], torch.split(lora_up, [2304, 768, 768], dim=0)):
1509
+ converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_B.weight"] = weight
1510
+
1511
+ # Process attention out weights.
1512
+ converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_A.weight"] = state_dict.pop(
1513
+ f"{prefix}.{index}.attention.out.lora_A.weight"
1514
+ )
1515
+ converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_B.weight"] = state_dict.pop(
1516
+ f"{prefix}.{index}.attention.out.lora_B.weight"
1517
+ )
1518
+
1519
+ # Process feed-forward weights for layers 1, 2, and 3.
1520
+ for layer in range(1, 4):
1521
+ converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_A.weight"] = state_dict.pop(
1522
+ f"{prefix}.{index}.feed_forward.w{layer}.lora_A.weight"
1523
+ )
1524
+ converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_B.weight"] = state_dict.pop(
1525
+ f"{prefix}.{index}.feed_forward.w{layer}.lora_B.weight"
1526
+ )
1527
+
1528
+ if convert_norm:
1529
+ converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_A.weight"] = state_dict.pop(
1530
+ f"{prefix}.{index}.adaLN_modulation.1.lora_A.weight"
1531
+ )
1532
+ converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_B.weight"] = state_dict.pop(
1533
+ f"{prefix}.{index}.adaLN_modulation.1.lora_B.weight"
1534
+ )
1535
+
1536
+ noise_refiner_pattern = r"noise_refiner\.(\d+)\."
1537
+ num_noise_refiner_layers = get_num_layers(state_dict.keys(), noise_refiner_pattern)
1538
+ for i in range(num_noise_refiner_layers):
1539
+ process_block("noise_refiner", i, convert_norm=True)
1540
+
1541
+ context_refiner_pattern = r"context_refiner\.(\d+)\."
1542
+ num_context_refiner_layers = get_num_layers(state_dict.keys(), context_refiner_pattern)
1543
+ for i in range(num_context_refiner_layers):
1544
+ process_block("context_refiner", i, convert_norm=False)
1545
+
1546
+ core_transformer_pattern = r"layers\.(\d+)\."
1547
+ num_core_transformer_layers = get_num_layers(state_dict.keys(), core_transformer_pattern)
1548
+ for i in range(num_core_transformer_layers):
1549
+ process_block("layers", i, convert_norm=True)
1550
+
1551
+ if len(state_dict) > 0:
1552
+ raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
1553
+
1554
+ for key in list(converted_state_dict.keys()):
1555
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
1556
+
1557
+ return converted_state_dict
1558
+
1559
+
1560
+ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
1561
+ converted_state_dict = {}
1562
+ original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
1563
+
1564
+ num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
1565
+ is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
1566
+
1567
+ for i in range(num_blocks):
1568
+ # Self-attention
1569
+ for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
1570
+ converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
1571
+ f"blocks.{i}.self_attn.{o}.lora_A.weight"
1572
+ )
1573
+ converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
1574
+ f"blocks.{i}.self_attn.{o}.lora_B.weight"
1575
+ )
1576
+
1577
+ # Cross-attention
1578
+ for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
1579
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
1580
+ f"blocks.{i}.cross_attn.{o}.lora_A.weight"
1581
+ )
1582
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
1583
+ f"blocks.{i}.cross_attn.{o}.lora_B.weight"
1584
+ )
1585
+
1586
+ if is_i2v_lora:
1587
+ for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
1588
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
1589
+ f"blocks.{i}.cross_attn.{o}.lora_A.weight"
1590
+ )
1591
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
1592
+ f"blocks.{i}.cross_attn.{o}.lora_B.weight"
1593
+ )
1594
+
1595
+ # FFN
1596
+ for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
1597
+ converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
1598
+ f"blocks.{i}.{o}.lora_A.weight"
1599
+ )
1600
+ converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
1601
+ f"blocks.{i}.{o}.lora_B.weight"
1602
+ )
1603
+
1604
+ if len(original_state_dict) > 0:
1605
+ raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
1606
+
1607
+ for key in list(converted_state_dict.keys()):
1608
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
1609
+
1610
+ return converted_state_dict