diffusers 0.33.1__py3-none-any.whl → 0.34.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (478) hide show
  1. diffusers/__init__.py +48 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/dependency_versions_check.py +1 -1
  7. diffusers/dependency_versions_table.py +1 -1
  8. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  9. diffusers/hooks/faster_cache.py +2 -2
  10. diffusers/hooks/group_offloading.py +128 -29
  11. diffusers/hooks/hooks.py +2 -2
  12. diffusers/hooks/layerwise_casting.py +3 -3
  13. diffusers/hooks/pyramid_attention_broadcast.py +1 -1
  14. diffusers/image_processor.py +7 -2
  15. diffusers/loaders/__init__.py +4 -0
  16. diffusers/loaders/ip_adapter.py +5 -14
  17. diffusers/loaders/lora_base.py +212 -111
  18. diffusers/loaders/lora_conversion_utils.py +275 -34
  19. diffusers/loaders/lora_pipeline.py +1554 -819
  20. diffusers/loaders/peft.py +52 -109
  21. diffusers/loaders/single_file.py +2 -2
  22. diffusers/loaders/single_file_model.py +20 -4
  23. diffusers/loaders/single_file_utils.py +225 -5
  24. diffusers/loaders/textual_inversion.py +3 -2
  25. diffusers/loaders/transformer_flux.py +1 -1
  26. diffusers/loaders/transformer_sd3.py +2 -2
  27. diffusers/loaders/unet.py +2 -16
  28. diffusers/loaders/unet_loader_utils.py +1 -1
  29. diffusers/loaders/utils.py +1 -1
  30. diffusers/models/__init__.py +15 -1
  31. diffusers/models/activations.py +5 -5
  32. diffusers/models/adapter.py +2 -3
  33. diffusers/models/attention.py +4 -4
  34. diffusers/models/attention_flax.py +10 -10
  35. diffusers/models/attention_processor.py +14 -10
  36. diffusers/models/auto_model.py +47 -10
  37. diffusers/models/autoencoders/__init__.py +1 -0
  38. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  39. diffusers/models/autoencoders/autoencoder_dc.py +3 -3
  40. diffusers/models/autoencoders/autoencoder_kl.py +4 -4
  41. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  42. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  43. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
  44. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  45. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  46. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  47. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  48. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  49. diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
  50. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  52. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  53. diffusers/models/autoencoders/vae.py +13 -2
  54. diffusers/models/autoencoders/vq_model.py +2 -2
  55. diffusers/models/cache_utils.py +1 -1
  56. diffusers/models/controlnet.py +1 -1
  57. diffusers/models/controlnet_flux.py +1 -1
  58. diffusers/models/controlnet_sd3.py +1 -1
  59. diffusers/models/controlnet_sparsectrl.py +1 -1
  60. diffusers/models/controlnets/__init__.py +1 -0
  61. diffusers/models/controlnets/controlnet.py +3 -3
  62. diffusers/models/controlnets/controlnet_flax.py +1 -1
  63. diffusers/models/controlnets/controlnet_flux.py +16 -15
  64. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  65. diffusers/models/controlnets/controlnet_sana.py +290 -0
  66. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  67. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  68. diffusers/models/controlnets/controlnet_union.py +1 -1
  69. diffusers/models/controlnets/controlnet_xs.py +7 -7
  70. diffusers/models/controlnets/multicontrolnet.py +4 -5
  71. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  72. diffusers/models/downsampling.py +2 -2
  73. diffusers/models/embeddings.py +10 -12
  74. diffusers/models/embeddings_flax.py +2 -2
  75. diffusers/models/lora.py +3 -3
  76. diffusers/models/modeling_utils.py +44 -14
  77. diffusers/models/normalization.py +4 -4
  78. diffusers/models/resnet.py +2 -2
  79. diffusers/models/resnet_flax.py +1 -1
  80. diffusers/models/transformers/__init__.py +5 -0
  81. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  82. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  83. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  84. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  85. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  86. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  87. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  88. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  89. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  90. diffusers/models/transformers/prior_transformer.py +1 -1
  91. diffusers/models/transformers/sana_transformer.py +8 -3
  92. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  93. diffusers/models/transformers/t5_film_transformer.py +3 -3
  94. diffusers/models/transformers/transformer_2d.py +1 -1
  95. diffusers/models/transformers/transformer_allegro.py +1 -1
  96. diffusers/models/transformers/transformer_chroma.py +742 -0
  97. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  98. diffusers/models/transformers/transformer_cogview4.py +317 -25
  99. diffusers/models/transformers/transformer_cosmos.py +579 -0
  100. diffusers/models/transformers/transformer_flux.py +9 -11
  101. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  102. diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
  103. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  104. diffusers/models/transformers/transformer_ltx.py +2 -2
  105. diffusers/models/transformers/transformer_lumina2.py +1 -1
  106. diffusers/models/transformers/transformer_mochi.py +1 -1
  107. diffusers/models/transformers/transformer_omnigen.py +2 -2
  108. diffusers/models/transformers/transformer_sd3.py +7 -7
  109. diffusers/models/transformers/transformer_temporal.py +1 -1
  110. diffusers/models/transformers/transformer_wan.py +24 -8
  111. diffusers/models/transformers/transformer_wan_vace.py +393 -0
  112. diffusers/models/unets/unet_1d.py +1 -1
  113. diffusers/models/unets/unet_1d_blocks.py +1 -1
  114. diffusers/models/unets/unet_2d.py +1 -1
  115. diffusers/models/unets/unet_2d_blocks.py +1 -1
  116. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  117. diffusers/models/unets/unet_2d_condition.py +2 -2
  118. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  119. diffusers/models/unets/unet_3d_blocks.py +1 -1
  120. diffusers/models/unets/unet_3d_condition.py +3 -3
  121. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  122. diffusers/models/unets/unet_kandinsky3.py +1 -1
  123. diffusers/models/unets/unet_motion_model.py +2 -2
  124. diffusers/models/unets/unet_stable_cascade.py +1 -1
  125. diffusers/models/upsampling.py +2 -2
  126. diffusers/models/vae_flax.py +2 -2
  127. diffusers/models/vq_model.py +1 -1
  128. diffusers/pipelines/__init__.py +37 -6
  129. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  130. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  131. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  132. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  133. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  134. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  135. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  136. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  137. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  138. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  139. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  140. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  141. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
  142. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  143. diffusers/pipelines/auto_pipeline.py +6 -7
  144. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  145. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  146. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  147. diffusers/pipelines/chroma/__init__.py +49 -0
  148. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  149. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  150. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  151. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
  152. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
  153. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
  154. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
  155. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  156. diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
  157. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  158. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  159. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  160. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  161. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  162. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
  163. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  164. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  165. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  166. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  167. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  168. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
  169. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
  170. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
  171. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  172. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  173. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  174. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  175. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  176. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  177. diffusers/pipelines/cosmos/__init__.py +54 -0
  178. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  179. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  180. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  181. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  182. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  183. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  184. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  185. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  186. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  187. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  188. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  189. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  190. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  191. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  192. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  193. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  194. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  195. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  196. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  197. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  198. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  199. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  200. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  201. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  202. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  203. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  204. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
  205. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  206. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  207. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  208. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  209. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  210. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  211. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  212. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  213. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  214. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  215. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  216. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  217. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  218. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  219. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  220. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  221. diffusers/pipelines/flux/modeling_flux.py +1 -1
  222. diffusers/pipelines/flux/pipeline_flux.py +10 -17
  223. diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
  224. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
  225. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
  226. diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
  227. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
  228. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
  229. diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
  230. diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
  231. diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
  232. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  233. diffusers/pipelines/free_init_utils.py +2 -2
  234. diffusers/pipelines/free_noise_utils.py +3 -3
  235. diffusers/pipelines/hidream_image/__init__.py +47 -0
  236. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  237. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  238. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  239. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  240. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
  241. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  242. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  243. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  244. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  245. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  246. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  247. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  248. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  249. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  250. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  251. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  252. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  253. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  254. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  255. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  256. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  257. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  258. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  259. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  260. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  261. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  262. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  263. diffusers/pipelines/kolors/text_encoder.py +3 -3
  264. diffusers/pipelines/kolors/tokenizer.py +1 -1
  265. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  266. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  267. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  268. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  269. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  270. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  271. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  272. diffusers/pipelines/ltx/__init__.py +4 -0
  273. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  274. diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
  275. diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
  276. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
  277. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  278. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  279. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  280. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  281. diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
  282. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  283. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  284. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  285. diffusers/pipelines/onnx_utils.py +15 -2
  286. diffusers/pipelines/pag/pag_utils.py +2 -2
  287. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  288. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  289. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  290. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  291. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  292. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  293. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  294. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  295. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  296. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  297. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  298. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  299. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  300. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  301. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  302. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  303. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  304. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  305. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  306. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  307. diffusers/pipelines/pipeline_flax_utils.py +3 -4
  308. diffusers/pipelines/pipeline_loading_utils.py +89 -13
  309. diffusers/pipelines/pipeline_utils.py +105 -33
  310. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
  311. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
  312. diffusers/pipelines/sana/__init__.py +4 -0
  313. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  314. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  315. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  316. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  317. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  318. diffusers/pipelines/shap_e/camera.py +1 -1
  319. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  320. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  321. diffusers/pipelines/shap_e/renderer.py +3 -3
  322. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  323. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  324. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  325. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  326. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  327. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  328. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  329. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  330. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  331. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  332. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  333. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
  334. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  335. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
  336. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
  337. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
  338. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  339. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  340. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  341. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  342. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  343. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  344. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  345. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  346. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  347. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  348. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  349. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  350. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
  351. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  352. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  353. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  354. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  355. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  356. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  357. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  358. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  359. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  360. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  361. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  362. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  363. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  364. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  365. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  366. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  367. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  368. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  369. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  370. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  371. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  372. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  373. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  374. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  375. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  376. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  377. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  378. diffusers/pipelines/unclip/text_proj.py +2 -2
  379. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  380. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  381. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  382. diffusers/pipelines/visualcloze/__init__.py +52 -0
  383. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  384. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  385. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  386. diffusers/pipelines/wan/__init__.py +2 -0
  387. diffusers/pipelines/wan/pipeline_wan.py +13 -10
  388. diffusers/pipelines/wan/pipeline_wan_i2v.py +38 -18
  389. diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
  390. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  391. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  392. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  393. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  394. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  395. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  396. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  397. diffusers/quantizers/__init__.py +179 -1
  398. diffusers/quantizers/base.py +6 -1
  399. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  400. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  401. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  402. diffusers/quantizers/gguf/utils.py +16 -13
  403. diffusers/quantizers/quantization_config.py +18 -16
  404. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  405. diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
  406. diffusers/schedulers/__init__.py +3 -1
  407. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  408. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  409. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  410. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  411. diffusers/schedulers/scheduling_ddim.py +8 -8
  412. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  413. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  414. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  415. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  416. diffusers/schedulers/scheduling_ddpm.py +9 -9
  417. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  418. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  419. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  420. diffusers/schedulers/scheduling_deis_multistep.py +8 -8
  421. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  422. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
  423. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  424. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  425. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  426. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
  427. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  428. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  429. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  430. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  431. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  432. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  433. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  434. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  435. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  436. diffusers/schedulers/scheduling_ipndm.py +2 -2
  437. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  438. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  439. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  440. diffusers/schedulers/scheduling_lcm.py +3 -3
  441. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  442. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  443. diffusers/schedulers/scheduling_pndm.py +4 -4
  444. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  445. diffusers/schedulers/scheduling_repaint.py +9 -9
  446. diffusers/schedulers/scheduling_sasolver.py +15 -15
  447. diffusers/schedulers/scheduling_scm.py +1 -1
  448. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  449. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  450. diffusers/schedulers/scheduling_tcd.py +3 -3
  451. diffusers/schedulers/scheduling_unclip.py +5 -5
  452. diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
  453. diffusers/schedulers/scheduling_utils.py +1 -1
  454. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  455. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  456. diffusers/training_utils.py +13 -5
  457. diffusers/utils/__init__.py +5 -0
  458. diffusers/utils/accelerate_utils.py +1 -1
  459. diffusers/utils/doc_utils.py +1 -1
  460. diffusers/utils/dummy_pt_objects.py +120 -0
  461. diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
  462. diffusers/utils/dynamic_modules_utils.py +21 -3
  463. diffusers/utils/export_utils.py +1 -1
  464. diffusers/utils/import_utils.py +81 -18
  465. diffusers/utils/logging.py +1 -1
  466. diffusers/utils/outputs.py +2 -1
  467. diffusers/utils/peft_utils.py +91 -8
  468. diffusers/utils/state_dict_utils.py +20 -3
  469. diffusers/utils/testing_utils.py +59 -7
  470. diffusers/utils/torch_utils.py +25 -5
  471. diffusers/video_processor.py +2 -2
  472. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/METADATA +70 -55
  473. diffusers-0.34.0.dist-info/RECORD +639 -0
  474. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/WHEEL +1 -1
  475. diffusers-0.33.1.dist-info/RECORD +0 -608
  476. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
  477. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
  478. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
14
14
 
15
15
  import copy
16
16
  import inspect
17
+ import json
17
18
  import os
18
19
  from pathlib import Path
19
20
  from typing import Callable, Dict, List, Optional, Union
@@ -33,7 +34,6 @@ from ..utils import (
33
34
  delete_adapter_layers,
34
35
  deprecate,
35
36
  get_adapter_name,
36
- get_peft_kwargs,
37
37
  is_accelerate_available,
38
38
  is_peft_available,
39
39
  is_peft_version,
@@ -45,13 +45,13 @@ from ..utils import (
45
45
  set_adapter_layers,
46
46
  set_weights_and_activate_adapters,
47
47
  )
48
+ from ..utils.peft_utils import _create_lora_config
49
+ from ..utils.state_dict_utils import _load_sft_state_dict_metadata
48
50
 
49
51
 
50
52
  if is_transformers_available():
51
53
  from transformers import PreTrainedModel
52
54
 
53
- from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
54
-
55
55
  if is_peft_available():
56
56
  from peft.tuners.tuners_utils import BaseTunerLayer
57
57
 
@@ -62,6 +62,7 @@ logger = logging.get_logger(__name__)
62
62
 
63
63
  LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
64
64
  LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
65
+ LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"
65
66
 
66
67
 
67
68
  def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
@@ -206,6 +207,7 @@ def _fetch_state_dict(
206
207
  subfolder,
207
208
  user_agent,
208
209
  allow_pickle,
210
+ metadata=None,
209
211
  ):
210
212
  model_file = None
211
213
  if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -236,11 +238,14 @@ def _fetch_state_dict(
236
238
  user_agent=user_agent,
237
239
  )
238
240
  state_dict = safetensors.torch.load_file(model_file, device="cpu")
241
+ metadata = _load_sft_state_dict_metadata(model_file)
242
+
239
243
  except (IOError, safetensors.SafetensorError) as e:
240
244
  if not allow_pickle:
241
245
  raise e
242
246
  # try loading non-safetensors weights
243
247
  model_file = None
248
+ metadata = None
244
249
  pass
245
250
 
246
251
  if model_file is None:
@@ -261,10 +266,11 @@ def _fetch_state_dict(
261
266
  user_agent=user_agent,
262
267
  )
263
268
  state_dict = load_state_dict(model_file)
269
+ metadata = None
264
270
  else:
265
271
  state_dict = pretrained_model_name_or_path_or_dict
266
272
 
267
- return state_dict
273
+ return state_dict, metadata
268
274
 
269
275
 
270
276
  def _best_guess_weight_name(
@@ -299,13 +305,18 @@ def _best_guess_weight_name(
299
305
  targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
300
306
 
301
307
  if len(targeted_files) > 1:
302
- raise ValueError(
303
- f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
308
+ logger.warning(
309
+ f"Provided path contains more than one weights file in the {file_extension} format. `{targeted_files[0]}` is going to be loaded, for precise control, specify a `weight_name` in `load_lora_weights`."
304
310
  )
305
311
  weight_name = targeted_files[0]
306
312
  return weight_name
307
313
 
308
314
 
315
+ def _pack_dict_with_prefix(state_dict, prefix):
316
+ sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
317
+ return sd_with_prefix
318
+
319
+
309
320
  def _load_lora_into_text_encoder(
310
321
  state_dict,
311
322
  network_alphas,
@@ -317,10 +328,14 @@ def _load_lora_into_text_encoder(
317
328
  _pipeline=None,
318
329
  low_cpu_mem_usage=False,
319
330
  hotswap: bool = False,
331
+ metadata=None,
320
332
  ):
321
333
  if not USE_PEFT_BACKEND:
322
334
  raise ValueError("PEFT backend is required for this method.")
323
335
 
336
+ if network_alphas and metadata:
337
+ raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.")
338
+
324
339
  peft_kwargs = {}
325
340
  if low_cpu_mem_usage:
326
341
  if not is_peft_version(">=", "0.13.1"):
@@ -335,8 +350,6 @@ def _load_lora_into_text_encoder(
335
350
  )
336
351
  peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
337
352
 
338
- from peft import LoraConfig
339
-
340
353
  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
341
354
  # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
342
355
  # their prefixes.
@@ -348,7 +361,9 @@ def _load_lora_into_text_encoder(
348
361
 
349
362
  # Load the layers corresponding to text encoder and make necessary adjustments.
350
363
  if prefix is not None:
351
- state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
364
+ state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
365
+ if metadata is not None:
366
+ metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
352
367
 
353
368
  if len(state_dict) > 0:
354
369
  logger.info(f"Loading {prefix}.")
@@ -358,54 +373,25 @@ def _load_lora_into_text_encoder(
358
373
  # convert state dict
359
374
  state_dict = convert_state_dict_to_peft(state_dict)
360
375
 
361
- for name, _ in text_encoder_attn_modules(text_encoder):
362
- for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
363
- rank_key = f"{name}.{module}.lora_B.weight"
364
- if rank_key not in state_dict:
365
- continue
366
- rank[rank_key] = state_dict[rank_key].shape[1]
367
-
368
- for name, _ in text_encoder_mlp_modules(text_encoder):
369
- for module in ("fc1", "fc2"):
370
- rank_key = f"{name}.{module}.lora_B.weight"
371
- if rank_key not in state_dict:
372
- continue
373
- rank[rank_key] = state_dict[rank_key].shape[1]
376
+ for name, _ in text_encoder.named_modules():
377
+ if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
378
+ rank_key = f"{name}.lora_B.weight"
379
+ if rank_key in state_dict:
380
+ rank[rank_key] = state_dict[rank_key].shape[1]
374
381
 
375
382
  if network_alphas is not None:
376
383
  alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
377
- network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
378
-
379
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
384
+ network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
380
385
 
381
- if "use_dora" in lora_config_kwargs:
382
- if lora_config_kwargs["use_dora"]:
383
- if is_peft_version("<", "0.9.0"):
384
- raise ValueError(
385
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
386
- )
387
- else:
388
- if is_peft_version("<", "0.9.0"):
389
- lora_config_kwargs.pop("use_dora")
390
-
391
- if "lora_bias" in lora_config_kwargs:
392
- if lora_config_kwargs["lora_bias"]:
393
- if is_peft_version("<=", "0.13.2"):
394
- raise ValueError(
395
- "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
396
- )
397
- else:
398
- if is_peft_version("<=", "0.13.2"):
399
- lora_config_kwargs.pop("lora_bias")
400
-
401
- lora_config = LoraConfig(**lora_config_kwargs)
386
+ # create `LoraConfig`
387
+ lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False)
402
388
 
403
389
  # adapter_name
404
390
  if adapter_name is None:
405
391
  adapter_name = get_adapter_name(text_encoder)
406
392
 
393
+ # <Unsafe code
407
394
  is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
408
-
409
395
  # inject LoRA layers and load the state dict
410
396
  # in transformers we automatically check whether the adapter name is already in use or not
411
397
  text_encoder.load_adapter(
@@ -417,7 +403,6 @@ def _load_lora_into_text_encoder(
417
403
 
418
404
  # scale LoRA layers with `lora_scale`
419
405
  scale_lora_layers(text_encoder, weight=lora_scale)
420
-
421
406
  text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
422
407
 
423
408
  # Offload back.
@@ -428,16 +413,28 @@ def _load_lora_into_text_encoder(
428
413
  # Unsafe code />
429
414
 
430
415
  if prefix is not None and not state_dict:
416
+ model_class_name = text_encoder.__class__.__name__
431
417
  logger.warning(
432
- f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
418
+ f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
433
419
  "This is safe to ignore if LoRA state dict didn't originally have any "
434
- f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
420
+ f"{model_class_name} related params. You can also try specifying `prefix=None` "
435
421
  "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
436
422
  "https://github.com/huggingface/diffusers/issues/new"
437
423
  )
438
424
 
439
425
 
440
426
  def _func_optionally_disable_offloading(_pipeline):
427
+ """
428
+ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
429
+
430
+ Args:
431
+ _pipeline (`DiffusionPipeline`):
432
+ The pipeline to disable offloading for.
433
+
434
+ Returns:
435
+ tuple:
436
+ A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
437
+ """
441
438
  is_model_cpu_offload = False
442
439
  is_sequential_cpu_offload = False
443
440
 
@@ -456,7 +453,8 @@ def _func_optionally_disable_offloading(_pipeline):
456
453
  logger.info(
457
454
  "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
458
455
  )
459
- remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
456
+ if is_sequential_cpu_offload or is_model_cpu_offload:
457
+ remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
460
458
 
461
459
  return (is_model_cpu_offload, is_sequential_cpu_offload)
462
460
 
@@ -465,7 +463,25 @@ class LoraBaseMixin:
465
463
  """Utility class for handling LoRAs."""
466
464
 
467
465
  _lora_loadable_modules = []
468
- num_fused_loras = 0
466
+ _merged_adapters = set()
467
+
468
+ @property
469
+ def lora_scale(self) -> float:
470
+ """
471
+ Returns the lora scale which can be set at run time by the pipeline. # if `_lora_scale` has not been set,
472
+ return 1.
473
+ """
474
+ return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
475
+
476
+ @property
477
+ def num_fused_loras(self):
478
+ """Returns the number of LoRAs that have been fused."""
479
+ return len(self._merged_adapters)
480
+
481
+ @property
482
+ def fused_loras(self):
483
+ """Returns names of the LoRAs that have been fused."""
484
+ return self._merged_adapters
469
485
 
470
486
  def load_lora_weights(self, **kwargs):
471
487
  raise NotImplementedError("`load_lora_weights()` is not implemented.")
@@ -478,33 +494,6 @@ class LoraBaseMixin:
478
494
  def lora_state_dict(cls, **kwargs):
479
495
  raise NotImplementedError("`lora_state_dict()` is not implemented.")
480
496
 
481
- @classmethod
482
- def _optionally_disable_offloading(cls, _pipeline):
483
- """
484
- Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
485
-
486
- Args:
487
- _pipeline (`DiffusionPipeline`):
488
- The pipeline to disable offloading for.
489
-
490
- Returns:
491
- tuple:
492
- A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
493
- """
494
- return _func_optionally_disable_offloading(_pipeline=_pipeline)
495
-
496
- @classmethod
497
- def _fetch_state_dict(cls, *args, **kwargs):
498
- deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
499
- deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
500
- return _fetch_state_dict(*args, **kwargs)
501
-
502
- @classmethod
503
- def _best_guess_weight_name(cls, *args, **kwargs):
504
- deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
505
- deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
506
- return _best_guess_weight_name(*args, **kwargs)
507
-
508
497
  def unload_lora_weights(self):
509
498
  """
510
499
  Unloads the LoRA parameters.
@@ -592,6 +581,9 @@ class LoraBaseMixin:
592
581
  if len(components) == 0:
593
582
  raise ValueError("`components` cannot be an empty list.")
594
583
 
584
+ # Need to retrieve the names as `adapter_names` can be None. So we cannot directly use it
585
+ # in `self._merged_adapters = self._merged_adapters | merged_adapter_names`.
586
+ merged_adapter_names = set()
595
587
  for fuse_component in components:
596
588
  if fuse_component not in self._lora_loadable_modules:
597
589
  raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
@@ -601,13 +593,19 @@ class LoraBaseMixin:
601
593
  # check if diffusers model
602
594
  if issubclass(model.__class__, ModelMixin):
603
595
  model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
596
+ for module in model.modules():
597
+ if isinstance(module, BaseTunerLayer):
598
+ merged_adapter_names.update(set(module.merged_adapters))
604
599
  # handle transformers models.
605
600
  if issubclass(model.__class__, PreTrainedModel):
606
601
  fuse_text_encoder_lora(
607
602
  model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
608
603
  )
604
+ for module in model.modules():
605
+ if isinstance(module, BaseTunerLayer):
606
+ merged_adapter_names.update(set(module.merged_adapters))
609
607
 
610
- self.num_fused_loras += 1
608
+ self._merged_adapters = self._merged_adapters | merged_adapter_names
611
609
 
612
610
  def unfuse_lora(self, components: List[str] = [], **kwargs):
613
611
  r"""
@@ -661,15 +659,42 @@ class LoraBaseMixin:
661
659
  if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
662
660
  for module in model.modules():
663
661
  if isinstance(module, BaseTunerLayer):
662
+ for adapter in set(module.merged_adapters):
663
+ if adapter and adapter in self._merged_adapters:
664
+ self._merged_adapters = self._merged_adapters - {adapter}
664
665
  module.unmerge()
665
666
 
666
- self.num_fused_loras -= 1
667
-
668
667
  def set_adapters(
669
668
  self,
670
669
  adapter_names: Union[List[str], str],
671
670
  adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
672
671
  ):
672
+ """
673
+ Set the currently active adapters for use in the pipeline.
674
+
675
+ Args:
676
+ adapter_names (`List[str]` or `str`):
677
+ The names of the adapters to use.
678
+ adapter_weights (`Union[List[float], float]`, *optional*):
679
+ The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
680
+ adapters.
681
+
682
+ Example:
683
+
684
+ ```py
685
+ from diffusers import AutoPipelineForText2Image
686
+ import torch
687
+
688
+ pipeline = AutoPipelineForText2Image.from_pretrained(
689
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
690
+ ).to("cuda")
691
+ pipeline.load_lora_weights(
692
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
693
+ )
694
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
695
+ pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
696
+ ```
697
+ """
673
698
  if isinstance(adapter_weights, dict):
674
699
  components_passed = set(adapter_weights.keys())
675
700
  lora_components = set(self._lora_loadable_modules)
@@ -739,6 +764,24 @@ class LoraBaseMixin:
739
764
  set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
740
765
 
741
766
  def disable_lora(self):
767
+ """
768
+ Disables the active LoRA layers of the pipeline.
769
+
770
+ Example:
771
+
772
+ ```py
773
+ from diffusers import AutoPipelineForText2Image
774
+ import torch
775
+
776
+ pipeline = AutoPipelineForText2Image.from_pretrained(
777
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
778
+ ).to("cuda")
779
+ pipeline.load_lora_weights(
780
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
781
+ )
782
+ pipeline.disable_lora()
783
+ ```
784
+ """
742
785
  if not USE_PEFT_BACKEND:
743
786
  raise ValueError("PEFT backend is required for this method.")
744
787
 
@@ -751,6 +794,24 @@ class LoraBaseMixin:
751
794
  disable_lora_for_text_encoder(model)
752
795
 
753
796
  def enable_lora(self):
797
+ """
798
+ Enables the active LoRA layers of the pipeline.
799
+
800
+ Example:
801
+
802
+ ```py
803
+ from diffusers import AutoPipelineForText2Image
804
+ import torch
805
+
806
+ pipeline = AutoPipelineForText2Image.from_pretrained(
807
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
808
+ ).to("cuda")
809
+ pipeline.load_lora_weights(
810
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
811
+ )
812
+ pipeline.enable_lora()
813
+ ```
814
+ """
754
815
  if not USE_PEFT_BACKEND:
755
816
  raise ValueError("PEFT backend is required for this method.")
756
817
 
@@ -764,10 +825,26 @@ class LoraBaseMixin:
764
825
 
765
826
  def delete_adapters(self, adapter_names: Union[List[str], str]):
766
827
  """
828
+ Delete an adapter's LoRA layers from the pipeline.
829
+
767
830
  Args:
768
- Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
769
831
  adapter_names (`Union[List[str], str]`):
770
- The names of the adapter to delete. Can be a single string or a list of strings
832
+ The names of the adapters to delete.
833
+
834
+ Example:
835
+
836
+ ```py
837
+ from diffusers import AutoPipelineForText2Image
838
+ import torch
839
+
840
+ pipeline = AutoPipelineForText2Image.from_pretrained(
841
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
842
+ ).to("cuda")
843
+ pipeline.load_lora_weights(
844
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
845
+ )
846
+ pipeline.delete_adapters("cinematic")
847
+ ```
771
848
  """
772
849
  if not USE_PEFT_BACKEND:
773
850
  raise ValueError("PEFT backend is required for this method.")
@@ -868,11 +945,28 @@ class LoraBaseMixin:
868
945
  adapter_name
869
946
  ].to(device)
870
947
 
948
+ def enable_lora_hotswap(self, **kwargs) -> None:
949
+ """
950
+ Hotswap adapters without triggering recompilation of a model or if the ranks of the loaded adapters are
951
+ different.
952
+
953
+ Args:
954
+ target_rank (`int`):
955
+ The highest rank among all the adapters that will be loaded.
956
+ check_compiled (`str`, *optional*, defaults to `"error"`):
957
+ How to handle a model that is already compiled. The check can return the following messages:
958
+ - "error" (default): raise an error
959
+ - "warn": issue a warning
960
+ - "ignore": do nothing
961
+ """
962
+ for key, component in self.components.items():
963
+ if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
964
+ component.enable_lora_hotswap(**kwargs)
965
+
871
966
  @staticmethod
872
967
  def pack_weights(layers, prefix):
873
968
  layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
874
- layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
875
- return layers_state_dict
969
+ return _pack_dict_with_prefix(layers_weights, prefix)
876
970
 
877
971
  @staticmethod
878
972
  def write_lora_layers(
@@ -882,16 +976,33 @@ class LoraBaseMixin:
882
976
  weight_name: str,
883
977
  save_function: Callable,
884
978
  safe_serialization: bool,
979
+ lora_adapter_metadata: Optional[dict] = None,
885
980
  ):
981
+ """Writes the state dict of the LoRA layers (optionally with metadata) to disk."""
886
982
  if os.path.isfile(save_directory):
887
983
  logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
888
984
  return
889
985
 
986
+ if lora_adapter_metadata and not safe_serialization:
987
+ raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
988
+ if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
989
+ raise TypeError("`lora_adapter_metadata` must be of type `dict`.")
990
+
890
991
  if save_function is None:
891
992
  if safe_serialization:
892
993
 
893
994
  def save_function(weights, filename):
894
- return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
995
+ # Inject framework format.
996
+ metadata = {"format": "pt"}
997
+ if lora_adapter_metadata:
998
+ for key, value in lora_adapter_metadata.items():
999
+ if isinstance(value, set):
1000
+ lora_adapter_metadata[key] = list(value)
1001
+ metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
1002
+ lora_adapter_metadata, indent=2, sort_keys=True
1003
+ )
1004
+
1005
+ return safetensors.torch.save_file(weights, filename, metadata=metadata)
895
1006
 
896
1007
  else:
897
1008
  save_function = torch.save
@@ -908,28 +1019,18 @@ class LoraBaseMixin:
908
1019
  save_function(state_dict, save_path)
909
1020
  logger.info(f"Model weights saved in {save_path}")
910
1021
 
911
- @property
912
- def lora_scale(self) -> float:
913
- # property function that returns the lora scale which can be set at run time by the pipeline.
914
- # if _lora_scale has not been set, return 1
915
- return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
916
-
917
- def enable_lora_hotswap(self, **kwargs) -> None:
918
- """Enables the possibility to hotswap LoRA adapters.
1022
+ @classmethod
1023
+ def _optionally_disable_offloading(cls, _pipeline):
1024
+ return _func_optionally_disable_offloading(_pipeline=_pipeline)
919
1025
 
920
- Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
921
- the loaded adapters differ.
1026
+ @classmethod
1027
+ def _fetch_state_dict(cls, *args, **kwargs):
1028
+ deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
1029
+ deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
1030
+ return _fetch_state_dict(*args, **kwargs)
922
1031
 
923
- Args:
924
- target_rank (`int`):
925
- The highest rank among all the adapters that will be loaded.
926
- check_compiled (`str`, *optional*, defaults to `"error"`):
927
- How to handle the case when the model is already compiled, which should generally be avoided. The
928
- options are:
929
- - "error" (default): raise an error
930
- - "warn": issue a warning
931
- - "ignore": do nothing
932
- """
933
- for key, component in self.components.items():
934
- if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
935
- component.enable_lora_hotswap(**kwargs)
1032
+ @classmethod
1033
+ def _best_guess_weight_name(cls, *args, **kwargs):
1034
+ deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
1035
+ deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
1036
+ return _best_guess_weight_name(*args, **kwargs)