diffusers-0.33.1-py3-none-any.whl → diffusers-0.35.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
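The comparison above is between the published 0.33.1 and 0.35.0 wheels; if you want to confirm which side of it a local environment is on, a standard-library check is enough (a minimal sketch, nothing here is specific to this diff beyond the package name):

```py
from importlib.metadata import version

# The per-file changes listed above apply when moving from 0.33.1 to 0.35.0.
print(version("diffusers"))
```

The expanded hunks that follow are from `diffusers/loaders/lora_base.py` (item 38 above, +261 -127).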
@@ -1,4 +1,4 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@

  import copy
  import inspect
+ import json
  import os
  from pathlib import Path
  from typing import Callable, Dict, List, Optional, Union
@@ -33,7 +34,6 @@ from ..utils import (
  delete_adapter_layers,
  deprecate,
  get_adapter_name,
- get_peft_kwargs,
  is_accelerate_available,
  is_peft_available,
  is_peft_version,
@@ -45,13 +45,13 @@ from ..utils import (
  set_adapter_layers,
  set_weights_and_activate_adapters,
  )
+ from ..utils.peft_utils import _create_lora_config
+ from ..utils.state_dict_utils import _load_sft_state_dict_metadata


  if is_transformers_available():
  from transformers import PreTrainedModel

- from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
-
  if is_peft_available():
  from peft.tuners.tuners_utils import BaseTunerLayer

@@ -62,6 +62,7 @@ logger = logging.get_logger(__name__)

  LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
  LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+ LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"


  def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
@@ -206,6 +207,7 @@ def _fetch_state_dict(
  subfolder,
  user_agent,
  allow_pickle,
+ metadata=None,
  ):
  model_file = None
  if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -236,11 +238,14 @@ def _fetch_state_dict(
  user_agent=user_agent,
  )
  state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ metadata = _load_sft_state_dict_metadata(model_file)
+
  except (IOError, safetensors.SafetensorError) as e:
  if not allow_pickle:
  raise e
  # try loading non-safetensors weights
  model_file = None
+ metadata = None
  pass

  if model_file is None:
@@ -261,10 +266,11 @@ def _fetch_state_dict(
  user_agent=user_agent,
  )
  state_dict = load_state_dict(model_file)
+ metadata = None
  else:
  state_dict = pretrained_model_name_or_path_or_dict

- return state_dict
+ return state_dict, metadata


  def _best_guess_weight_name(
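The hunks above change `_fetch_state_dict` to return a `(state_dict, metadata)` pair, where `metadata` is only populated when a safetensors checkpoint carries LoRA adapter metadata (read via the newly imported `_load_sft_state_dict_metadata`). That helper itself is not shown in this diff; the sketch below only demonstrates the underlying safetensors metadata mechanism it builds on, using the `lora_adapter_metadata` key added above (the file name and tensor shapes are made up for illustration):

```py
import json

import torch
from safetensors import safe_open
from safetensors.torch import save_file

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"  # constant added in this diff

# Write a tiny LoRA-style checkpoint; safetensors metadata values must be strings.
tensors = {"layer.lora_A.weight": torch.zeros(4, 16), "layer.lora_B.weight": torch.zeros(16, 4)}
save_file(
    tensors,
    "adapter.safetensors",
    metadata={LORA_ADAPTER_METADATA_KEY: json.dumps({"r": 4, "lora_alpha": 8})},
)

# Read it back: this is the kind of lookup the new `metadata` plumbing propagates
# alongside the state dict instead of discarding it.
with safe_open("adapter.safetensors", framework="pt") as f:
    raw = f.metadata() or {}
metadata = json.loads(raw[LORA_ADAPTER_METADATA_KEY]) if LORA_ADAPTER_METADATA_KEY in raw else None
print(metadata)  # {'r': 4, 'lora_alpha': 8}
```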
@@ -299,13 +305,18 @@ def _best_guess_weight_name(
  targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))

  if len(targeted_files) > 1:
- raise ValueError(
- f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
+ logger.warning(
+ f"Provided path contains more than one weights file in the {file_extension} format. `{targeted_files[0]}` is going to be loaded, for precise control, specify a `weight_name` in `load_lora_weights`."
  )
  weight_name = targeted_files[0]
  return weight_name


+ def _pack_dict_with_prefix(state_dict, prefix):
+ sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
+ return sd_with_prefix
+
+
  def _load_lora_into_text_encoder(
  state_dict,
  network_alphas,
@@ -317,10 +328,16 @@ def _load_lora_into_text_encoder(
  _pipeline=None,
  low_cpu_mem_usage=False,
  hotswap: bool = False,
+ metadata=None,
  ):
+ from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
+
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ if network_alphas and metadata:
+ raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.")
+
  peft_kwargs = {}
  if low_cpu_mem_usage:
  if not is_peft_version(">=", "0.13.1"):
@@ -335,8 +352,6 @@ def _load_lora_into_text_encoder(
  )
  peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

- from peft import LoraConfig
-
  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
  # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
  # their prefixes.
@@ -348,7 +363,9 @@ def _load_lora_into_text_encoder(

  # Load the layers corresponding to text encoder and make necessary adjustments.
  if prefix is not None:
- state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+ state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+ if metadata is not None:
+ metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}

  if len(state_dict) > 0:
  logger.info(f"Loading {prefix}.")
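The prefix stripping now uses `str.removeprefix` (available since Python 3.9) instead of slicing by the prefix length, and the same treatment is applied to the optional `metadata` dict. For keys that pass the `startswith` filter the two spellings are equivalent, which a standalone check confirms (the key below is an illustrative example, not taken from the diff):

```py
prefix = "text_encoder"
key = "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight"

# Old form: slice off len("text_encoder."). New form: str.removeprefix.
old_style = key[len(f"{prefix}.") :]
new_style = key.removeprefix(f"{prefix}.")
assert old_style == new_style
print(new_style)  # text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight
```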
@@ -358,54 +375,27 @@ def _load_lora_into_text_encoder(
  # convert state dict
  state_dict = convert_state_dict_to_peft(state_dict)

- for name, _ in text_encoder_attn_modules(text_encoder):
- for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in state_dict:
- continue
- rank[rank_key] = state_dict[rank_key].shape[1]
-
- for name, _ in text_encoder_mlp_modules(text_encoder):
- for module in ("fc1", "fc2"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in state_dict:
- continue
- rank[rank_key] = state_dict[rank_key].shape[1]
+ for name, _ in text_encoder.named_modules():
+ if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
+ rank_key = f"{name}.lora_B.weight"
+ if rank_key in state_dict:
+ rank[rank_key] = state_dict[rank_key].shape[1]

  if network_alphas is not None:
  alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
- network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
-
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"]:
- if is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- if is_peft_version("<", "0.9.0"):
- lora_config_kwargs.pop("use_dora")
-
- if "lora_bias" in lora_config_kwargs:
- if lora_config_kwargs["lora_bias"]:
- if is_peft_version("<=", "0.13.2"):
- raise ValueError(
- "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- if is_peft_version("<=", "0.13.2"):
- lora_config_kwargs.pop("lora_bias")
+ network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}

- lora_config = LoraConfig(**lora_config_kwargs)
+ # create `LoraConfig`
+ lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False)

  # adapter_name
  if adapter_name is None:
  adapter_name = get_adapter_name(text_encoder)

- is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
-
+ # <Unsafe code
+ is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
+ _pipeline
+ )
  # inject LoRA layers and load the state dict
  # in transformers we automatically check whether the adapter name is already in use or not
  text_encoder.load_adapter(
@@ -417,7 +407,6 @@ def _load_lora_into_text_encoder(

  # scale LoRA layers with `lora_scale`
  scale_lora_layers(text_encoder, weight=lora_scale)
-
  text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

  # Offload back.
@@ -425,47 +414,90 @@ def _load_lora_into_text_encoder(
  _pipeline.enable_model_cpu_offload()
  elif is_sequential_cpu_offload:
  _pipeline.enable_sequential_cpu_offload()
+ elif is_group_offload:
+ for component in _pipeline.components.values():
+ if isinstance(component, torch.nn.Module):
+ _maybe_remove_and_reapply_group_offloading(component)
  # Unsafe code />

  if prefix is not None and not state_dict:
+ model_class_name = text_encoder.__class__.__name__
  logger.warning(
- f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
+ f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
  "This is safe to ignore if LoRA state dict didn't originally have any "
- f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
+ f"{model_class_name} related params. You can also try specifying `prefix=None` "
  "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
  "https://github.com/huggingface/diffusers/issues/new"
  )


  def _func_optionally_disable_offloading(_pipeline):
+ """
+ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
+
+ Args:
+ _pipeline (`DiffusionPipeline`):
+ The pipeline to disable offloading for.
+
+ Returns:
+ tuple:
+ A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
+ """
+ from ..hooks.group_offloading import _is_group_offload_enabled
+
  is_model_cpu_offload = False
  is_sequential_cpu_offload = False
+ is_group_offload = False

  if _pipeline is not None and _pipeline.hf_device_map is None:
  for _, component in _pipeline.components.items():
- if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
- if not is_model_cpu_offload:
- is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
- if not is_sequential_cpu_offload:
- is_sequential_cpu_offload = (
- isinstance(component._hf_hook, AlignDevicesHook)
- or hasattr(component._hf_hook, "hooks")
- and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
- )
+ if not isinstance(component, nn.Module):
+ continue
+ is_group_offload = is_group_offload or _is_group_offload_enabled(component)
+ if not hasattr(component, "_hf_hook"):
+ continue
+ is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
+ is_sequential_cpu_offload = is_sequential_cpu_offload or (
+ isinstance(component._hf_hook, AlignDevicesHook)
+ or hasattr(component._hf_hook, "hooks")
+ and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+ )

- logger.info(
- "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
- )
+ if is_sequential_cpu_offload or is_model_cpu_offload:
+ logger.info(
+ "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+ )
+ for _, component in _pipeline.components.items():
+ if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
+ continue
  remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

- return (is_model_cpu_offload, is_sequential_cpu_offload)
+ return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)


  class LoraBaseMixin:
  """Utility class for handling LoRAs."""

  _lora_loadable_modules = []
- num_fused_loras = 0
+ _merged_adapters = set()
+
+ @property
+ def lora_scale(self) -> float:
+ """
+ Returns the lora scale which can be set at run time by the pipeline. # if `_lora_scale` has not been set,
+ return 1.
+ """
+ return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+
+ @property
+ def num_fused_loras(self):
+ """Returns the number of LoRAs that have been fused."""
+ return len(self._merged_adapters)
+
+ @property
+ def fused_loras(self):
+ """Returns names of the LoRAs that have been fused."""
+ return self._merged_adapters

  def load_lora_weights(self, **kwargs):
  raise NotImplementedError("`load_lora_weights()` is not implemented.")
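Two behavioural changes land in this hunk: `_func_optionally_disable_offloading` now returns a third flag for group offloading, and fused-adapter bookkeeping moves from a manually incremented `num_fused_loras` counter to a `_merged_adapters` set exposed through the new `num_fused_loras` and `fused_loras` properties. A hedged sketch of what that looks like from the pipeline side, reusing the model and LoRA repo ids from the docstring examples elsewhere in this file:

```py
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

pipe.fuse_lora(adapter_names=["pixel"])
print(pipe.num_fused_loras)  # 1, now computed as len(pipe.fused_loras)
print(pipe.fused_loras)      # names of the merged adapters, e.g. {"pixel"}

pipe.unfuse_lora()
print(pipe.num_fused_loras)  # 0 again; unmerged adapters are removed from the set
```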
@@ -478,33 +510,6 @@ class LoraBaseMixin:
  def lora_state_dict(cls, **kwargs):
  raise NotImplementedError("`lora_state_dict()` is not implemented.")

- @classmethod
- def _optionally_disable_offloading(cls, _pipeline):
- """
- Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
- Args:
- _pipeline (`DiffusionPipeline`):
- The pipeline to disable offloading for.
-
- Returns:
- tuple:
- A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
- """
- return _func_optionally_disable_offloading(_pipeline=_pipeline)
-
- @classmethod
- def _fetch_state_dict(cls, *args, **kwargs):
- deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
- deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
- return _fetch_state_dict(*args, **kwargs)
-
- @classmethod
- def _best_guess_weight_name(cls, *args, **kwargs):
- deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
- deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
- return _best_guess_weight_name(*args, **kwargs)
-
  def unload_lora_weights(self):
  """
  Unloads the LoRA parameters.
@@ -592,6 +597,9 @@ class LoraBaseMixin:
  if len(components) == 0:
  raise ValueError("`components` cannot be an empty list.")

+ # Need to retrieve the names as `adapter_names` can be None. So we cannot directly use it
+ # in `self._merged_adapters = self._merged_adapters | merged_adapter_names`.
+ merged_adapter_names = set()
  for fuse_component in components:
  if fuse_component not in self._lora_loadable_modules:
  raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
@@ -601,13 +609,19 @@
  # check if diffusers model
  if issubclass(model.__class__, ModelMixin):
  model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
+ for module in model.modules():
+ if isinstance(module, BaseTunerLayer):
+ merged_adapter_names.update(set(module.merged_adapters))
  # handle transformers models.
  if issubclass(model.__class__, PreTrainedModel):
  fuse_text_encoder_lora(
  model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
  )
+ for module in model.modules():
+ if isinstance(module, BaseTunerLayer):
+ merged_adapter_names.update(set(module.merged_adapters))

- self.num_fused_loras += 1
+ self._merged_adapters = self._merged_adapters | merged_adapter_names

  def unfuse_lora(self, components: List[str] = [], **kwargs):
  r"""
@@ -661,15 +675,42 @@
  if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
  for module in model.modules():
  if isinstance(module, BaseTunerLayer):
+ for adapter in set(module.merged_adapters):
+ if adapter and adapter in self._merged_adapters:
+ self._merged_adapters = self._merged_adapters - {adapter}
  module.unmerge()

- self.num_fused_loras -= 1
-
  def set_adapters(
  self,
  adapter_names: Union[List[str], str],
  adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
  ):
+ """
+ Set the currently active adapters for use in the pipeline.
+
+ Args:
+ adapter_names (`List[str]` or `str`):
+ The names of the adapters to use.
+ adapter_weights (`Union[List[float], float]`, *optional*):
+ The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
+ adapters.
+
+ Example:
+
+ ```py
+ from diffusers import AutoPipelineForText2Image
+ import torch
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights(
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+ )
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+ ```
+ """
  if isinstance(adapter_weights, dict):
  components_passed = set(adapter_weights.keys())
  lora_components = set(self._lora_loadable_modules)
@@ -713,7 +754,11 @@
  # Decompose weights into weights for denoiser and text encoders.
  _component_adapter_weights = {}
  for component in self._lora_loadable_modules:
- model = getattr(self, component)
+ model = getattr(self, component, None)
+ # To guard for cases like Wan. In Wan2.1 and WanVace, we have a single denoiser.
+ # Whereas in Wan 2.2, we have two denoisers.
+ if model is None:
+ continue

  for adapter_name, weights in zip(adapter_names, adapter_weights):
  if isinstance(weights, dict):
@@ -739,6 +784,24 @@
  set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])

  def disable_lora(self):
+ """
+ Disables the active LoRA layers of the pipeline.
+
+ Example:
+
+ ```py
+ from diffusers import AutoPipelineForText2Image
+ import torch
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights(
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+ )
+ pipeline.disable_lora()
+ ```
+ """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

@@ -751,6 +814,24 @@
  disable_lora_for_text_encoder(model)

  def enable_lora(self):
+ """
+ Enables the active LoRA layers of the pipeline.
+
+ Example:
+
+ ```py
+ from diffusers import AutoPipelineForText2Image
+ import torch
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights(
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+ )
+ pipeline.enable_lora()
+ ```
+ """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

@@ -764,10 +845,26 @@ class LoraBaseMixin:

  def delete_adapters(self, adapter_names: Union[List[str], str]):
  """
+ Delete an adapter's LoRA layers from the pipeline.
+
  Args:
- Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
  adapter_names (`Union[List[str], str]`):
- The names of the adapter to delete. Can be a single string or a list of strings
+ The names of the adapters to delete.
+
+ Example:
+
+ ```py
+ from diffusers import AutoPipelineForText2Image
+ import torch
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights(
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+ )
+ pipeline.delete_adapters("cinematic")
+ ```
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")
@@ -844,6 +941,27 @@ class LoraBaseMixin:
  Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
  you want to load multiple adapters and free some GPU memory.

+ After offloading the LoRA adapters to CPU, as long as the rest of the model is still on GPU, the LoRA adapters
+ can no longer be used for inference, as that would cause a device mismatch. Remember to set the device back to
+ GPU before using those LoRA adapters for inference.
+
+ ```python
+ >>> pipe.load_lora_weights(path_1, adapter_name="adapter-1")
+ >>> pipe.load_lora_weights(path_2, adapter_name="adapter-2")
+ >>> pipe.set_adapters("adapter-1")
+ >>> image_1 = pipe(**kwargs)
+ >>> # switch to adapter-2, offload adapter-1
+ >>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cpu")
+ >>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cuda:0")
+ >>> pipe.set_adapters("adapter-2")
+ >>> image_2 = pipe(**kwargs)
+ >>> # switch back to adapter-1, offload adapter-2
+ >>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cpu")
+ >>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cuda:0")
+ >>> pipe.set_adapters("adapter-1")
+ >>> ...
+ ```
+
  Args:
  adapter_names (`List[str]`):
  List of adapters to send device to.
@@ -859,6 +977,10 @@ class LoraBaseMixin:
  for module in model.modules():
  if isinstance(module, BaseTunerLayer):
  for adapter_name in adapter_names:
+ if adapter_name not in module.lora_A:
+ # it is sufficient to check lora_A
+ continue
+
  module.lora_A[adapter_name].to(device)
  module.lora_B[adapter_name].to(device)
  # this is a param, not a module, so device placement is not in-place -> re-assign
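The `lora_A` membership check above relies on PEFT populating a module's `lora_A` and `lora_B` containers together, keyed by adapter name, so testing one of them is enough to skip modules into which a given adapter was never injected. A toy illustration, with plain dicts standing in for the per-module containers:

```py
# Hedged illustration only: plain dicts stand in for PEFT's per-module containers.
# lora_A and lora_B always share keys, so checking lora_A alone is sufficient.
module_lora_A = {"pixel": "A-weights"}  # adapter "cinematic" was never injected here
module_lora_B = {"pixel": "B-weights"}

for adapter_name in ["pixel", "cinematic"]:
    if adapter_name not in module_lora_A:  # sufficient to check lora_A
        continue
    print(f"moving {adapter_name}: {module_lora_A[adapter_name]}, {module_lora_B[adapter_name]}")
```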
@@ -868,11 +990,28 @@ class LoraBaseMixin:
  adapter_name
  ].to(device)

+ def enable_lora_hotswap(self, **kwargs) -> None:
+ """
+ Enables hotswapping of LoRA adapters, so that adapters can be swapped without triggering recompilation of a
+ compiled model, even when the ranks of the loaded adapters differ.
+
+ Args:
+ target_rank (`int`):
+ The highest rank among all the adapters that will be loaded.
+ check_compiled (`str`, *optional*, defaults to `"error"`):
+ How to handle a model that is already compiled. The options are:
+ - "error" (default): raise an error
+ - "warn": issue a warning
+ - "ignore": do nothing
+ """
+ for key, component in self.components.items():
+ if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
+ component.enable_lora_hotswap(**kwargs)
+
  @staticmethod
  def pack_weights(layers, prefix):
  layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
- layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
- return layers_state_dict
+ return _pack_dict_with_prefix(layers_weights, prefix)

  @staticmethod
  def write_lora_layers(
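`enable_lora_hotswap` now sits alongside the other public LoRA helpers and simply forwards to each LoRA-loadable component. A hedged sketch of the recompilation-free swap workflow it supports; the repo ids are placeholders, and it assumes this diffusers version exposes `hotswap=True` on `load_lora_weights`:

```py
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Prepare for hotswapping before loading the first adapter, using the highest
# rank among all adapters that will ever be loaded.
pipe.enable_lora_hotswap(target_rank=64)

pipe.load_lora_weights("your-org/lora-a", adapter_name="default_0")  # placeholder repo id
pipe.unet = torch.compile(pipe.unet)  # compile once
image_a = pipe("a prompt").images[0]

# Swap the adapter weights in place; no recompilation should be triggered.
pipe.load_lora_weights("your-org/lora-b", adapter_name="default_0", hotswap=True)
image_b = pipe("a prompt").images[0]
```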
@@ -882,16 +1021,33 @@ class LoraBaseMixin:
  weight_name: str,
  save_function: Callable,
  safe_serialization: bool,
+ lora_adapter_metadata: Optional[dict] = None,
  ):
+ """Writes the state dict of the LoRA layers (optionally with metadata) to disk."""
  if os.path.isfile(save_directory):
  logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
  return

+ if lora_adapter_metadata and not safe_serialization:
+ raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
+ if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
+ raise TypeError("`lora_adapter_metadata` must be of type `dict`.")
+
  if save_function is None:
  if safe_serialization:

  def save_function(weights, filename):
- return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+ # Inject framework format.
+ metadata = {"format": "pt"}
+ if lora_adapter_metadata:
+ for key, value in lora_adapter_metadata.items():
+ if isinstance(value, set):
+ lora_adapter_metadata[key] = list(value)
+ metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
+ lora_adapter_metadata, indent=2, sort_keys=True
+ )
+
+ return safetensors.torch.save_file(weights, filename, metadata=metadata)

  else:
  save_function = torch.save
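With `lora_adapter_metadata`, the adapter configuration is serialized as a JSON string into the safetensors header (under the library's `LORA_ADAPTER_METADATA_KEY`) next to the usual `"format": "pt"` entry. A hedged sketch of reading that header back from a saved file; the file name is a placeholder, and every header entry is dumped rather than hard-coding the exact key:

```py
import json
import safetensors

# Placeholder path to a LoRA file written with lora_adapter_metadata.
with safetensors.safe_open("pytorch_lora_weights.safetensors", framework="pt") as f:
    header_metadata = f.metadata() or {}

for key, value in header_metadata.items():
    try:
        print(key, json.loads(value))  # adapter metadata is stored as a JSON string
    except json.JSONDecodeError:
        print(key, value)              # plain entries such as "format": "pt"
```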
@@ -908,28 +1064,6 @@ class LoraBaseMixin:
  save_function(state_dict, save_path)
  logger.info(f"Model weights saved in {save_path}")

- @property
- def lora_scale(self) -> float:
- # property function that returns the lora scale which can be set at run time by the pipeline.
- # if _lora_scale has not been set, return 1
- return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
-
- def enable_lora_hotswap(self, **kwargs) -> None:
- """Enables the possibility to hotswap LoRA adapters.
-
- Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
- the loaded adapters differ.
-
- Args:
- target_rank (`int`):
- The highest rank among all the adapters that will be loaded.
- check_compiled (`str`, *optional*, defaults to `"error"`):
- How to handle the case when the model is already compiled, which should generally be avoided. The
- options are:
- - "error" (default): raise an error
- - "warn": issue a warning
- - "ignore": do nothing
- """
- for key, component in self.components.items():
- if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
- component.enable_lora_hotswap(**kwargs)
+ @classmethod
+ def _optionally_disable_offloading(cls, _pipeline):
+ return _func_optionally_disable_offloading(_pipeline=_pipeline)