diffusers 0.33.1__py3-none-any.whl → 0.34.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (478)
  1. diffusers/__init__.py +48 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/dependency_versions_check.py +1 -1
  7. diffusers/dependency_versions_table.py +1 -1
  8. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  9. diffusers/hooks/faster_cache.py +2 -2
  10. diffusers/hooks/group_offloading.py +128 -29
  11. diffusers/hooks/hooks.py +2 -2
  12. diffusers/hooks/layerwise_casting.py +3 -3
  13. diffusers/hooks/pyramid_attention_broadcast.py +1 -1
  14. diffusers/image_processor.py +7 -2
  15. diffusers/loaders/__init__.py +4 -0
  16. diffusers/loaders/ip_adapter.py +5 -14
  17. diffusers/loaders/lora_base.py +212 -111
  18. diffusers/loaders/lora_conversion_utils.py +275 -34
  19. diffusers/loaders/lora_pipeline.py +1554 -819
  20. diffusers/loaders/peft.py +52 -109
  21. diffusers/loaders/single_file.py +2 -2
  22. diffusers/loaders/single_file_model.py +20 -4
  23. diffusers/loaders/single_file_utils.py +225 -5
  24. diffusers/loaders/textual_inversion.py +3 -2
  25. diffusers/loaders/transformer_flux.py +1 -1
  26. diffusers/loaders/transformer_sd3.py +2 -2
  27. diffusers/loaders/unet.py +2 -16
  28. diffusers/loaders/unet_loader_utils.py +1 -1
  29. diffusers/loaders/utils.py +1 -1
  30. diffusers/models/__init__.py +15 -1
  31. diffusers/models/activations.py +5 -5
  32. diffusers/models/adapter.py +2 -3
  33. diffusers/models/attention.py +4 -4
  34. diffusers/models/attention_flax.py +10 -10
  35. diffusers/models/attention_processor.py +14 -10
  36. diffusers/models/auto_model.py +47 -10
  37. diffusers/models/autoencoders/__init__.py +1 -0
  38. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  39. diffusers/models/autoencoders/autoencoder_dc.py +3 -3
  40. diffusers/models/autoencoders/autoencoder_kl.py +4 -4
  41. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  42. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  43. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
  44. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  45. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  46. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  47. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  48. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  49. diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
  50. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  52. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  53. diffusers/models/autoencoders/vae.py +13 -2
  54. diffusers/models/autoencoders/vq_model.py +2 -2
  55. diffusers/models/cache_utils.py +1 -1
  56. diffusers/models/controlnet.py +1 -1
  57. diffusers/models/controlnet_flux.py +1 -1
  58. diffusers/models/controlnet_sd3.py +1 -1
  59. diffusers/models/controlnet_sparsectrl.py +1 -1
  60. diffusers/models/controlnets/__init__.py +1 -0
  61. diffusers/models/controlnets/controlnet.py +3 -3
  62. diffusers/models/controlnets/controlnet_flax.py +1 -1
  63. diffusers/models/controlnets/controlnet_flux.py +16 -15
  64. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  65. diffusers/models/controlnets/controlnet_sana.py +290 -0
  66. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  67. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  68. diffusers/models/controlnets/controlnet_union.py +1 -1
  69. diffusers/models/controlnets/controlnet_xs.py +7 -7
  70. diffusers/models/controlnets/multicontrolnet.py +4 -5
  71. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  72. diffusers/models/downsampling.py +2 -2
  73. diffusers/models/embeddings.py +10 -12
  74. diffusers/models/embeddings_flax.py +2 -2
  75. diffusers/models/lora.py +3 -3
  76. diffusers/models/modeling_utils.py +44 -14
  77. diffusers/models/normalization.py +4 -4
  78. diffusers/models/resnet.py +2 -2
  79. diffusers/models/resnet_flax.py +1 -1
  80. diffusers/models/transformers/__init__.py +5 -0
  81. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  82. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  83. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  84. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  85. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  86. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  87. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  88. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  89. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  90. diffusers/models/transformers/prior_transformer.py +1 -1
  91. diffusers/models/transformers/sana_transformer.py +8 -3
  92. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  93. diffusers/models/transformers/t5_film_transformer.py +3 -3
  94. diffusers/models/transformers/transformer_2d.py +1 -1
  95. diffusers/models/transformers/transformer_allegro.py +1 -1
  96. diffusers/models/transformers/transformer_chroma.py +742 -0
  97. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  98. diffusers/models/transformers/transformer_cogview4.py +317 -25
  99. diffusers/models/transformers/transformer_cosmos.py +579 -0
  100. diffusers/models/transformers/transformer_flux.py +9 -11
  101. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  102. diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
  103. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  104. diffusers/models/transformers/transformer_ltx.py +2 -2
  105. diffusers/models/transformers/transformer_lumina2.py +1 -1
  106. diffusers/models/transformers/transformer_mochi.py +1 -1
  107. diffusers/models/transformers/transformer_omnigen.py +2 -2
  108. diffusers/models/transformers/transformer_sd3.py +7 -7
  109. diffusers/models/transformers/transformer_temporal.py +1 -1
  110. diffusers/models/transformers/transformer_wan.py +24 -8
  111. diffusers/models/transformers/transformer_wan_vace.py +393 -0
  112. diffusers/models/unets/unet_1d.py +1 -1
  113. diffusers/models/unets/unet_1d_blocks.py +1 -1
  114. diffusers/models/unets/unet_2d.py +1 -1
  115. diffusers/models/unets/unet_2d_blocks.py +1 -1
  116. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  117. diffusers/models/unets/unet_2d_condition.py +2 -2
  118. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  119. diffusers/models/unets/unet_3d_blocks.py +1 -1
  120. diffusers/models/unets/unet_3d_condition.py +3 -3
  121. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  122. diffusers/models/unets/unet_kandinsky3.py +1 -1
  123. diffusers/models/unets/unet_motion_model.py +2 -2
  124. diffusers/models/unets/unet_stable_cascade.py +1 -1
  125. diffusers/models/upsampling.py +2 -2
  126. diffusers/models/vae_flax.py +2 -2
  127. diffusers/models/vq_model.py +1 -1
  128. diffusers/pipelines/__init__.py +37 -6
  129. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  130. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  131. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  132. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  133. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  134. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  135. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  136. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  137. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  138. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  139. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  140. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  141. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
  142. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  143. diffusers/pipelines/auto_pipeline.py +6 -7
  144. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  145. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  146. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  147. diffusers/pipelines/chroma/__init__.py +49 -0
  148. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  149. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  150. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  151. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
  152. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
  153. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
  154. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
  155. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  156. diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
  157. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  158. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  159. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  160. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  161. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  162. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
  163. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  164. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  165. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  166. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  167. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  168. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
  169. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
  170. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
  171. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  172. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  173. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  174. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  175. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  176. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  177. diffusers/pipelines/cosmos/__init__.py +54 -0
  178. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  179. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  180. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  181. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  182. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  183. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  184. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  185. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  186. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  187. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  188. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  189. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  190. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  191. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  192. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  193. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  194. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  195. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  196. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  197. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  198. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  199. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  200. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  201. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  202. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  203. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  204. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
  205. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  206. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  207. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  208. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  209. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  210. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  211. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  212. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  213. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  214. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  215. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  216. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  217. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  218. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  219. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  220. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  221. diffusers/pipelines/flux/modeling_flux.py +1 -1
  222. diffusers/pipelines/flux/pipeline_flux.py +10 -17
  223. diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
  224. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
  225. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
  226. diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
  227. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
  228. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
  229. diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
  230. diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
  231. diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
  232. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  233. diffusers/pipelines/free_init_utils.py +2 -2
  234. diffusers/pipelines/free_noise_utils.py +3 -3
  235. diffusers/pipelines/hidream_image/__init__.py +47 -0
  236. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  237. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  238. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  239. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  240. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
  241. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  242. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  243. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  244. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  245. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  246. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  247. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  248. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  249. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  250. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  251. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  252. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  253. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  254. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  255. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  256. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  257. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  258. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  259. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  260. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  261. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  262. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  263. diffusers/pipelines/kolors/text_encoder.py +3 -3
  264. diffusers/pipelines/kolors/tokenizer.py +1 -1
  265. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  266. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  267. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  268. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  269. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  270. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  271. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  272. diffusers/pipelines/ltx/__init__.py +4 -0
  273. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  274. diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
  275. diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
  276. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
  277. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  278. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  279. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  280. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  281. diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
  282. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  283. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  284. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  285. diffusers/pipelines/onnx_utils.py +15 -2
  286. diffusers/pipelines/pag/pag_utils.py +2 -2
  287. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  288. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  289. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  290. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  291. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  292. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  293. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  294. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  295. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  296. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  297. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  298. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  299. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  300. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  301. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  302. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  303. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  304. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  305. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  306. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  307. diffusers/pipelines/pipeline_flax_utils.py +3 -4
  308. diffusers/pipelines/pipeline_loading_utils.py +89 -13
  309. diffusers/pipelines/pipeline_utils.py +105 -33
  310. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
  311. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
  312. diffusers/pipelines/sana/__init__.py +4 -0
  313. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  314. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  315. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  316. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  317. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  318. diffusers/pipelines/shap_e/camera.py +1 -1
  319. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  320. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  321. diffusers/pipelines/shap_e/renderer.py +3 -3
  322. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  323. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  324. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  325. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  326. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  327. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  328. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  329. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  330. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  331. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  332. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  333. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
  334. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  335. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
  336. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
  337. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
  338. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  339. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  340. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  341. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  342. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  343. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  344. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  345. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  346. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  347. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  348. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  349. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  350. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
  351. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  352. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  353. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  354. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  355. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  356. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  357. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  358. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  359. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  360. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  361. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  362. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  363. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  364. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  365. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  366. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  367. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  368. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  369. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  370. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  371. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  372. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  373. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  374. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  375. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  376. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  377. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  378. diffusers/pipelines/unclip/text_proj.py +2 -2
  379. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  380. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  381. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  382. diffusers/pipelines/visualcloze/__init__.py +52 -0
  383. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  384. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  385. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  386. diffusers/pipelines/wan/__init__.py +2 -0
  387. diffusers/pipelines/wan/pipeline_wan.py +13 -10
  388. diffusers/pipelines/wan/pipeline_wan_i2v.py +38 -18
  389. diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
  390. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  391. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  392. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  393. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  394. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  395. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  396. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  397. diffusers/quantizers/__init__.py +179 -1
  398. diffusers/quantizers/base.py +6 -1
  399. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  400. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  401. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  402. diffusers/quantizers/gguf/utils.py +16 -13
  403. diffusers/quantizers/quantization_config.py +18 -16
  404. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  405. diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
  406. diffusers/schedulers/__init__.py +3 -1
  407. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  408. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  409. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  410. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  411. diffusers/schedulers/scheduling_ddim.py +8 -8
  412. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  413. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  414. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  415. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  416. diffusers/schedulers/scheduling_ddpm.py +9 -9
  417. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  418. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  419. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  420. diffusers/schedulers/scheduling_deis_multistep.py +8 -8
  421. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  422. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
  423. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  424. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  425. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  426. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
  427. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  428. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  429. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  430. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  431. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  432. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  433. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  434. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  435. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  436. diffusers/schedulers/scheduling_ipndm.py +2 -2
  437. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  438. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  439. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  440. diffusers/schedulers/scheduling_lcm.py +3 -3
  441. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  442. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  443. diffusers/schedulers/scheduling_pndm.py +4 -4
  444. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  445. diffusers/schedulers/scheduling_repaint.py +9 -9
  446. diffusers/schedulers/scheduling_sasolver.py +15 -15
  447. diffusers/schedulers/scheduling_scm.py +1 -1
  448. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  449. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  450. diffusers/schedulers/scheduling_tcd.py +3 -3
  451. diffusers/schedulers/scheduling_unclip.py +5 -5
  452. diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
  453. diffusers/schedulers/scheduling_utils.py +1 -1
  454. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  455. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  456. diffusers/training_utils.py +13 -5
  457. diffusers/utils/__init__.py +5 -0
  458. diffusers/utils/accelerate_utils.py +1 -1
  459. diffusers/utils/doc_utils.py +1 -1
  460. diffusers/utils/dummy_pt_objects.py +120 -0
  461. diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
  462. diffusers/utils/dynamic_modules_utils.py +21 -3
  463. diffusers/utils/export_utils.py +1 -1
  464. diffusers/utils/import_utils.py +81 -18
  465. diffusers/utils/logging.py +1 -1
  466. diffusers/utils/outputs.py +2 -1
  467. diffusers/utils/peft_utils.py +91 -8
  468. diffusers/utils/state_dict_utils.py +20 -3
  469. diffusers/utils/testing_utils.py +59 -7
  470. diffusers/utils/torch_utils.py +25 -5
  471. diffusers/video_processor.py +2 -2
  472. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/METADATA +70 -55
  473. diffusers-0.34.0.dist-info/RECORD +639 -0
  474. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/WHEEL +1 -1
  475. diffusers-0.33.1.dist-info/RECORD +0 -608
  476. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
  477. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
  478. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -33,6 +33,24 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
33
33
  # 1. get all state_dict_keys
34
34
  all_keys = list(state_dict.keys())
35
35
  sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
36
+ not_sgm_patterns = ["down_blocks", "mid_block", "up_blocks"]
37
+
38
+ # check if state_dict contains both patterns
39
+ contains_sgm_patterns = False
40
+ contains_not_sgm_patterns = False
41
+ for key in all_keys:
42
+ if any(p in key for p in sgm_patterns):
43
+ contains_sgm_patterns = True
44
+ elif any(p in key for p in not_sgm_patterns):
45
+ contains_not_sgm_patterns = True
46
+
47
+ # if state_dict contains both patterns, remove sgm
48
+ # we can then return state_dict immediately
49
+ if contains_sgm_patterns and contains_not_sgm_patterns:
50
+ for key in all_keys:
51
+ if any(p in key for p in sgm_patterns):
52
+ state_dict.pop(key)
53
+ return state_dict
36
54
 
37
55
  # 2. check if needs remapping, if not return original dict
38
56
  is_in_sgm_format = False
@@ -126,7 +144,7 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
126
144
  )
127
145
  new_state_dict[new_key] = state_dict.pop(key)
128
146
 
129
- if len(state_dict) > 0:
147
+ if state_dict:
130
148
  raise ValueError("At this point all state dict entries have to be converted.")
131
149
 
132
150
  return new_state_dict
@@ -415,7 +433,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
415
433
  ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
416
434
  if not is_sparse:
417
435
  # down_weight is copied to each split
418
- ait_sd.update({k: down_weight for k in ait_down_keys})
436
+ ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
419
437
 
420
438
  # up_weight is split to each split
421
439
  ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
@@ -709,8 +727,25 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
709
727
  elif k.startswith("lora_te1_"):
710
728
  has_te_keys = True
711
729
  continue
730
+ elif k.startswith("lora_transformer_context_embedder"):
731
+ diffusers_key = "context_embedder"
732
+ elif k.startswith("lora_transformer_norm_out_linear"):
733
+ diffusers_key = "norm_out.linear"
734
+ elif k.startswith("lora_transformer_proj_out"):
735
+ diffusers_key = "proj_out"
736
+ elif k.startswith("lora_transformer_x_embedder"):
737
+ diffusers_key = "x_embedder"
738
+ elif k.startswith("lora_transformer_time_text_embed_guidance_embedder_linear_"):
739
+ i = int(k.split("lora_transformer_time_text_embed_guidance_embedder_linear_")[-1])
740
+ diffusers_key = f"time_text_embed.guidance_embedder.linear_{i}"
741
+ elif k.startswith("lora_transformer_time_text_embed_text_embedder_linear_"):
742
+ i = int(k.split("lora_transformer_time_text_embed_text_embedder_linear_")[-1])
743
+ diffusers_key = f"time_text_embed.text_embedder.linear_{i}"
744
+ elif k.startswith("lora_transformer_time_text_embed_timestep_embedder_linear_"):
745
+ i = int(k.split("lora_transformer_time_text_embed_timestep_embedder_linear_")[-1])
746
+ diffusers_key = f"time_text_embed.timestep_embedder.linear_{i}"
712
747
  else:
713
- raise NotImplementedError
748
+ raise NotImplementedError(f"Handling for key ({k}) is not implemented.")
714
749
 
715
750
  if "attn_" in k:
716
751
  if "_to_out_0" in k:
@@ -801,7 +836,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
801
836
  if zero_status_pe:
802
837
  logger.info(
803
838
  "The `position_embedding` LoRA params are all zeros which make them ineffective. "
804
- "So, we will purge them out of the curret state dict to make loading possible."
839
+ "So, we will purge them out of the current state dict to make loading possible."
805
840
  )
806
841
 
807
842
  else:
@@ -817,7 +852,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
817
852
  if zero_status_t5:
818
853
  logger.info(
819
854
  "The `t5xxl` LoRA params are all zeros which make them ineffective. "
820
- "So, we will purge them out of the curret state dict to make loading possible."
855
+ "So, we will purge them out of the current state dict to make loading possible."
821
856
  )
822
857
  else:
823
858
  logger.info(
@@ -832,7 +867,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
832
867
  if zero_status_diff_b:
833
868
  logger.info(
834
869
  "The `diff_b` LoRA params are all zeros which make them ineffective. "
835
- "So, we will purge them out of the curret state dict to make loading possible."
870
+ "So, we will purge them out of the current state dict to make loading possible."
836
871
  )
837
872
  else:
838
873
  logger.info(
@@ -848,7 +883,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
848
883
  if zero_status_diff:
849
884
  logger.info(
850
885
  "The `diff` LoRA params are all zeros which make them ineffective. "
851
- "So, we will purge them out of the curret state dict to make loading possible."
886
+ "So, we will purge them out of the current state dict to make loading possible."
852
887
  )
853
888
  else:
854
889
  logger.info(
@@ -905,7 +940,7 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
905
940
  ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
906
941
 
907
942
  # down_weight is copied to each split
908
- ait_sd.update({k: down_weight for k in ait_down_keys})
943
+ ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
909
944
 
910
945
  # up_weight is split to each split
911
946
  ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
@@ -1219,7 +1254,7 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
1219
1254
  f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
1220
1255
  )
1221
1256
 
1222
- # single transfomer blocks
1257
+ # single transformer blocks
1223
1258
  for i in range(num_single_layers):
1224
1259
  block_prefix = f"single_transformer_blocks.{i}."
1225
1260
 
@@ -1561,45 +1596,235 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
1561
1596
  converted_state_dict = {}
1562
1597
  original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
1563
1598
 
1564
- num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
1599
+ block_numbers = {int(k.split(".")[1]) for k in original_state_dict if k.startswith("blocks.")}
1600
+ min_block = min(block_numbers)
1601
+ max_block = max(block_numbers)
1602
+
1565
1603
  is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
1604
+ lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
1605
+ lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
1606
+
1607
+ diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
1608
+ if diff_keys:
1609
+ for diff_k in diff_keys:
1610
+ param = original_state_dict[diff_k]
1611
+ # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are upto 1e-3,
1612
+ # and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond
1613
+ # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
1614
+ # is okay to ignore because they do not affect the model output in a significant manner.
1615
+ threshold = 1.6e-2
1616
+ absdiff = param.abs().max() - param.abs().min()
1617
+ all_zero = torch.all(param == 0).item()
1618
+ all_absdiff_lower_than_threshold = absdiff < threshold
1619
+ if all_zero or all_absdiff_lower_than_threshold:
1620
+ logger.debug(
1621
+ f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
1622
+ )
1623
+ original_state_dict.pop(diff_k)
1566
1624
 
1567
- for i in range(num_blocks):
1625
+ # For the `diff_b` keys, we treat them as lora_bias.
1626
+ # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
1627
+
1628
+ for i in range(min_block, max_block + 1):
1568
1629
  # Self-attention
1569
1630
  for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
1570
- converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
1571
- f"blocks.{i}.self_attn.{o}.lora_A.weight"
1572
- )
1573
- converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
1574
- f"blocks.{i}.self_attn.{o}.lora_B.weight"
1575
- )
1631
+ original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
1632
+ converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
1633
+ if original_key in original_state_dict:
1634
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1635
+
1636
+ original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
1637
+ converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
1638
+ if original_key in original_state_dict:
1639
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1640
+
1641
+ original_key = f"blocks.{i}.self_attn.{o}.diff_b"
1642
+ converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
1643
+ if original_key in original_state_dict:
1644
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1576
1645
 
1577
1646
  # Cross-attention
1578
1647
  for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
1579
- converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
1580
- f"blocks.{i}.cross_attn.{o}.lora_A.weight"
1581
- )
1582
- converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
1583
- f"blocks.{i}.cross_attn.{o}.lora_B.weight"
1584
- )
1648
+ original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
1649
+ converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
1650
+ if original_key in original_state_dict:
1651
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1652
+
1653
+ original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
1654
+ converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
1655
+ if original_key in original_state_dict:
1656
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1657
+
1658
+ original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
1659
+ converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
1660
+ if original_key in original_state_dict:
1661
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1585
1662
 
1586
1663
  if is_i2v_lora:
1587
1664
  for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
1588
- converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
1589
- f"blocks.{i}.cross_attn.{o}.lora_A.weight"
1590
- )
1591
- converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
1592
- f"blocks.{i}.cross_attn.{o}.lora_B.weight"
1593
- )
1665
+ original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
1666
+ converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
1667
+ if original_key in original_state_dict:
1668
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1669
+
1670
+ original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
1671
+ converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
1672
+ if original_key in original_state_dict:
1673
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1674
+
1675
+ original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
1676
+ converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
1677
+ if original_key in original_state_dict:
1678
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1594
1679
 
1595
1680
  # FFN
1596
1681
  for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
1597
- converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
1598
- f"blocks.{i}.{o}.lora_A.weight"
1682
+ original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
1683
+ converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
1684
+ if original_key in original_state_dict:
1685
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1686
+
1687
+ original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
1688
+ converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
1689
+ if original_key in original_state_dict:
1690
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1691
+
1692
+ original_key = f"blocks.{i}.{o}.diff_b"
1693
+ converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
1694
+ if original_key in original_state_dict:
1695
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1696
+
1697
+ # Remaining.
1698
+ if original_state_dict:
1699
+ if any("time_projection" in k for k in original_state_dict):
1700
+ original_key = f"time_projection.1.{lora_down_key}.weight"
1701
+ converted_key = "condition_embedder.time_proj.lora_A.weight"
1702
+ if original_key in original_state_dict:
1703
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1704
+
1705
+ original_key = f"time_projection.1.{lora_up_key}.weight"
1706
+ converted_key = "condition_embedder.time_proj.lora_B.weight"
1707
+ if original_key in original_state_dict:
1708
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1709
+
1710
+ if "time_projection.1.diff_b" in original_state_dict:
1711
+ converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
1712
+ "time_projection.1.diff_b"
1713
+ )
1714
+
1715
+ if any("head.head" in k for k in state_dict):
1716
+ converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
1717
+ f"head.head.{lora_down_key}.weight"
1599
1718
  )
1600
- converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
1601
- f"blocks.{i}.{o}.lora_B.weight"
1719
+ converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
1720
+ if "head.head.diff_b" in original_state_dict:
1721
+ converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
1722
+
1723
+ for text_time in ["text_embedding", "time_embedding"]:
1724
+ if any(text_time in k for k in original_state_dict):
1725
+ for b_n in [0, 2]:
1726
+ diffusers_b_n = 1 if b_n == 0 else 2
1727
+ diffusers_name = (
1728
+ "condition_embedder.text_embedder"
1729
+ if text_time == "text_embedding"
1730
+ else "condition_embedder.time_embedder"
1731
+ )
1732
+ if any(f"{text_time}.{b_n}" in k for k in original_state_dict):
1733
+ converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_A.weight"] = (
1734
+ original_state_dict.pop(f"{text_time}.{b_n}.{lora_down_key}.weight")
1735
+ )
1736
+ converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.weight"] = (
1737
+ original_state_dict.pop(f"{text_time}.{b_n}.{lora_up_key}.weight")
1738
+ )
1739
+ if f"{text_time}.{b_n}.diff_b" in original_state_dict:
1740
+ converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.bias"] = (
1741
+ original_state_dict.pop(f"{text_time}.{b_n}.diff_b")
1742
+ )
1743
+
1744
+ for img_ours, img_theirs in [
1745
+ ("ff.net.0.proj", "img_emb.proj.1"),
1746
+ ("ff.net.2", "img_emb.proj.3"),
1747
+ ]:
1748
+ original_key = f"{img_theirs}.{lora_down_key}.weight"
1749
+ converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_A.weight"
1750
+ if original_key in original_state_dict:
1751
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1752
+
1753
+ original_key = f"{img_theirs}.{lora_up_key}.weight"
1754
+ converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight"
1755
+ if original_key in original_state_dict:
1756
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1757
+
1758
+ if len(original_state_dict) > 0:
1759
+ diff = all(".diff" in k for k in original_state_dict)
1760
+ if diff:
1761
+ diff_keys = {k for k in original_state_dict if k.endswith(".diff")}
1762
+ if not all("lora" not in k for k in diff_keys):
1763
+ raise ValueError
1764
+ logger.info(
1765
+ "The remaining `state_dict` contains `diff` keys which we do not handle yet. If you see performance issues, please file an issue: "
1766
+ "https://github.com/huggingface/diffusers//issues/new"
1602
1767
  )
1768
+ else:
1769
+ raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
1770
+
1771
+ for key in list(converted_state_dict.keys()):
1772
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
1773
+
1774
+ return converted_state_dict
1775
+
1776
+
1777
+ def _convert_musubi_wan_lora_to_diffusers(state_dict):
1778
+ # https://github.com/kohya-ss/musubi-tuner
1779
+ converted_state_dict = {}
1780
+ original_state_dict = {k[len("lora_unet_") :]: v for k, v in state_dict.items()}
1781
+
1782
+ num_blocks = len({k.split("blocks_")[1].split("_")[0] for k in original_state_dict})
1783
+ is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
1784
+
1785
+ def get_alpha_scales(down_weight, key):
1786
+ rank = down_weight.shape[0]
1787
+ alpha = original_state_dict.pop(key + ".alpha").item()
1788
+ scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
1789
+ scale_down = scale
1790
+ scale_up = 1.0
1791
+ while scale_down * 2 < scale_up:
1792
+ scale_down *= 2
1793
+ scale_up /= 2
1794
+ return scale_down, scale_up
1795
+
1796
+ for i in range(num_blocks):
1797
+ # Self-attention
1798
+ for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
1799
+ down_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_down.weight")
1800
+ up_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_up.weight")
1801
+ scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_self_attn_{o}")
1802
+ converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = down_weight * scale_down
1803
+ converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = up_weight * scale_up
1804
+
1805
+ # Cross-attention
1806
+ for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
1807
+ down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
1808
+ up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
1809
+ scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
1810
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
1811
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
1812
+
1813
+ if is_i2v_lora:
1814
+ for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
1815
+ down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
1816
+ up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
1817
+ scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
1818
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
1819
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
1820
+
1821
+ # FFN
1822
+ for o, c in zip(["ffn_0", "ffn_2"], ["net.0.proj", "net.2"]):
1823
+ down_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_down.weight")
1824
+ up_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_up.weight")
1825
+ scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_{o}")
1826
+ converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = down_weight * scale_down
1827
+ converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = up_weight * scale_up
1603
1828
 
1604
1829
  if len(original_state_dict) > 0:
1605
1830
  raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
@@ -1608,3 +1833,19 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
1608
1833
  converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
1609
1834
 
1610
1835
  return converted_state_dict
1836
+
1837
+
1838
+ def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
1839
+ if not all(k.startswith(non_diffusers_prefix) for k in state_dict):
1840
+ raise ValueError("Invalid LoRA state dict for HiDream.")
1841
+ converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
1842
+ converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
1843
+ return converted_state_dict
1844
+
1845
+
1846
+ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
1847
+ if not all(k.startswith(f"{non_diffusers_prefix}.") for k in state_dict):
1848
+ raise ValueError("Invalid LoRA state dict for LTX-Video.")
1849
+ converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
1850
+ converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
1851
+ return converted_state_dict