diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/lora_pipeline.py
@@ -22,6 +22,8 @@ from ..utils import (
  USE_PEFT_BACKEND,
  deprecate,
  get_submodule_by_name,
+ is_bitsandbytes_available,
+ is_gguf_available,
  is_peft_available,
  is_peft_version,
  is_torch_version,
@@ -41,6 +43,8 @@ from .lora_conversion_utils import (
  _convert_hunyuan_video_lora_to_diffusers,
  _convert_kohya_flux_lora_to_diffusers,
  _convert_non_diffusers_lora_to_diffusers,
+ _convert_non_diffusers_lumina2_lora_to_diffusers,
+ _convert_non_diffusers_wan_lora_to_diffusers,
  _convert_xlabs_flux_lora_to_diffusers,
  _maybe_map_sgm_blocks_to_diffusers,
  )
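The two new imports above are checkpoint converters for Lumina2 and Wan LoRAs trained outside diffusers. Conceptually, such converters remap trainer-specific key names onto diffusers module paths; the sketch below is a deliberately simplified, self-contained illustration of that idea (the key names and mapping table are made up, not the library's actual conversion rules):

```python
from typing import Dict
import torch

def convert_keys(state_dict: Dict[str, torch.Tensor], key_map: Dict[str, str]) -> Dict[str, torch.Tensor]:
    # Remap each key's leading component prefix according to key_map; leave other keys untouched.
    converted = {}
    for key, value in state_dict.items():
        for old_prefix, new_prefix in key_map.items():
            if key.startswith(old_prefix):
                key = new_prefix + key[len(old_prefix):]
                break
        converted[key] = value
    return converted

# Toy example: a trainer-style key becomes a diffusers-style "transformer." key.
example = {"diffusion_model.blocks.0.attn.q.lora_A.weight": torch.zeros(4, 8)}
print(convert_keys(example, {"diffusion_model.": "transformer."}))
```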
@@ -66,6 +70,49 @@ TRANSFORMER_NAME = "transformer"
  _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}


+ def _maybe_dequantize_weight_for_expanded_lora(model, module):
+ if is_bitsandbytes_available():
+ from ..quantizers.bitsandbytes import dequantize_bnb_weight
+
+ if is_gguf_available():
+ from ..quantizers.gguf.utils import dequantize_gguf_tensor
+
+ is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+ is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
+
+ if is_bnb_4bit_quantized and not is_bitsandbytes_available():
+ raise ValueError(
+ "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
+ )
+ if is_gguf_quantized and not is_gguf_available():
+ raise ValueError(
+ "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
+ )
+
+ weight_on_cpu = False
+ if not module.weight.is_cuda:
+ weight_on_cpu = True
+
+ if is_bnb_4bit_quantized:
+ module_weight = dequantize_bnb_weight(
+ module.weight.cuda() if weight_on_cpu else module.weight,
+ state=module.weight.quant_state,
+ dtype=model.dtype,
+ ).data
+ elif is_gguf_quantized:
+ module_weight = dequantize_gguf_tensor(
+ module.weight.cuda() if weight_on_cpu else module.weight,
+ )
+ module_weight = module_weight.to(model.dtype)
+ else:
+ module_weight = module.weight.data
+
+ if weight_on_cpu:
+ module_weight = module_weight.cpu()
+
+ return module_weight
+
+
  class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  r"""
  Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
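The new helper above returns an ordinary tensor for a weight that may be bitsandbytes-4bit or GGUF quantized, so a LoRA that expects a wider input dimension can still be applied. A minimal, hypothetical usage sketch; `transformer`, its `x_embedder` attribute, and the doubled input width are assumptions for illustration, not taken from the diff:

```python
import torch

# Hypothetical: `transformer` is a loaded diffusers transformer whose `x_embedder` Linear
# weight may be bnb-4bit or GGUF quantized; the helper returns a plain dequantized tensor.
module = transformer.x_embedder
weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)

# A LoRA with a wider input (e.g. a control-style LoRA) can then be accommodated by
# zero-padding the dequantized weight along in_features before copying it into an expanded layer.
new_in_features = 2 * weight.shape[1]  # assumption for illustration only
expanded = torch.zeros(weight.shape[0], new_in_features, dtype=weight.dtype, device=weight.device)
expanded[:, : weight.shape[1]] = weight
```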
@@ -77,10 +124,13 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  text_encoder_name = TEXT_ENCODER_NAME

  def load_lora_weights(
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name=None,
+ hotswap: bool = False,
+ **kwargs,
  ):
- """
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+ """Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
  `self.text_encoder`.

  All kwargs are forwarded to `self.lora_state_dict`.
@@ -103,6 +153,29 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
  weights.
+ hotswap : (`bool`, *optional*)
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
  kwargs (`dict`, *optional*):
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
  """
@@ -133,6 +206,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  adapter_name=adapter_name,
  _pipeline=self,
  low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
  )
  self.load_lora_into_text_encoder(
  state_dict,
@@ -144,6 +218,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  adapter_name=adapter_name,
  _pipeline=self,
  low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
  )

  @classmethod
@@ -263,7 +338,14 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):

  @classmethod
  def load_lora_into_unet(
- cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ network_alphas,
+ unet,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -285,6 +367,29 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading only loading the pretrained LoRA weights and not initializing the random
  weights.
+ hotswap : (`bool`, *optional*)
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")
@@ -297,19 +402,16 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
  # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
  # their prefixes.
- keys = list(state_dict.keys())
- only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
- if not only_text_encoder:
- # Load the layers corresponding to UNet.
- logger.info(f"Loading {cls.unet_name}.")
- unet.load_lora_adapter(
- state_dict,
- prefix=cls.unet_name,
- network_alphas=network_alphas,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
+ logger.info(f"Loading {cls.unet_name}.")
+ unet.load_lora_adapter(
+ state_dict,
+ prefix=cls.unet_name,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )

  @classmethod
  def load_lora_into_text_encoder(
@@ -322,6 +424,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  adapter_name=None,
  _pipeline=None,
  low_cpu_mem_usage=False,
+ hotswap: bool = False,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -347,6 +450,29 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
  weights.
+ hotswap : (`bool`, *optional*)
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
  """
  _load_lora_into_text_encoder(
  state_dict=state_dict,
@@ -358,6 +484,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  adapter_name=adapter_name,
  _pipeline=_pipeline,
  low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
  )

  @classmethod
@@ -454,7 +581,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  ```
  """
  super().fuse_lora(
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
  )

  def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
@@ -475,7 +606,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
  LoRA parameters then it won't have any effect.
  """
- super().unfuse_lora(components=components)
+ super().unfuse_lora(components=components, **kwargs)


  class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
@@ -558,31 +689,26 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  _pipeline=self,
  low_cpu_mem_usage=low_cpu_mem_usage,
  )
- text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
- if len(text_encoder_state_dict) > 0:
- self.load_lora_into_text_encoder(
- text_encoder_state_dict,
- network_alphas=network_alphas,
- text_encoder=self.text_encoder,
- prefix="text_encoder",
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
-
- text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
- if len(text_encoder_2_state_dict) > 0:
- self.load_lora_into_text_encoder(
- text_encoder_2_state_dict,
- network_alphas=network_alphas,
- text_encoder=self.text_encoder_2,
- prefix="text_encoder_2",
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder,
+ prefix=self.text_encoder_name,
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder_2,
+ prefix=f"{self.text_encoder_name}_2",
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )

  @classmethod
  @validate_hf_hub_args
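The refactor above stops pre-filtering the state dict per text encoder: the full `state_dict` is now passed along with a `prefix`, and the callee is expected to select its own keys (and to do nothing when none match). A self-contained sketch of that kind of prefix selection, using an illustrative helper rather than the library's internals:

```python
from typing import Dict
import torch

def select_by_prefix(state_dict: Dict[str, torch.Tensor], prefix: str) -> Dict[str, torch.Tensor]:
    # Keep only the keys belonging to one component and strip that component's prefix.
    prefix = f"{prefix}."
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

# Toy state dict with UNet-only LoRA keys: the text-encoder selections come back empty,
# which is why the explicit `len(...) > 0` guards removed above are no longer needed.
state_dict = {"unet.down_blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 8)}
unet_sd = select_by_prefix(state_dict, "unet")            # 1 entry
te2_sd = select_by_prefix(state_dict, "text_encoder_2")   # {}
```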
@@ -703,7 +829,14 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  @classmethod
  # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
  def load_lora_into_unet(
- cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ network_alphas,
+ unet,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -725,6 +858,29 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading only loading the pretrained LoRA weights and not initializing the random
  weights.
+ hotswap : (`bool`, *optional*)
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")
@@ -737,19 +893,16 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
  # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
  # their prefixes.
- keys = list(state_dict.keys())
- only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
- if not only_text_encoder:
- # Load the layers corresponding to UNet.
- logger.info(f"Loading {cls.unet_name}.")
- unet.load_lora_adapter(
- state_dict,
- prefix=cls.unet_name,
- network_alphas=network_alphas,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
+ logger.info(f"Loading {cls.unet_name}.")
+ unet.load_lora_adapter(
+ state_dict,
+ prefix=cls.unet_name,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )

  @classmethod
  # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -763,6 +916,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  adapter_name=None,
  _pipeline=None,
  low_cpu_mem_usage=False,
+ hotswap: bool = False,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -788,6 +942,29 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
  weights.
+ hotswap : (`bool`, *optional*)
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
  """
  _load_lora_into_text_encoder(
  state_dict=state_dict,
@@ -799,6 +976,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  adapter_name=adapter_name,
  _pipeline=_pipeline,
  low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
  )

  @classmethod
@@ -842,11 +1020,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):

  if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
  raise ValueError(
- "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
+ "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
  )

  if unet_lora_layers:
- state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
+ state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))

  if text_encoder_lora_layers:
  state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
@@ -903,7 +1081,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  ```
  """
  super().fuse_lora(
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
  )

  def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
@@ -924,7 +1106,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
  LoRA parameters then it won't have any effect.
  """
- super().unfuse_lora(components=components)
+ super().unfuse_lora(components=components, **kwargs)


  class SD3LoraLoaderMixin(LoraBaseMixin):
@@ -1038,7 +1220,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  return state_dict

  def load_lora_weights(
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name=None,
+ hotswap: bool = False,
+ **kwargs,
  ):
  """
  Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
@@ -1061,6 +1247,29 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
  weights.
+ hotswap : (`bool`, *optional*)
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
  kwargs (`dict`, *optional*):
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
  """
@@ -1084,47 +1293,40 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  if not is_correct_format:
  raise ValueError("Invalid LoRA checkpoint.")

- transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
- if len(transformer_state_dict) > 0:
- self.load_lora_into_transformer(
- state_dict,
- transformer=getattr(self, self.transformer_name)
- if not hasattr(self, "transformer")
- else self.transformer,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
-
- text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
- if len(text_encoder_state_dict) > 0:
- self.load_lora_into_text_encoder(
- text_encoder_state_dict,
- network_alphas=None,
- text_encoder=self.text_encoder,
- prefix="text_encoder",
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
-
- text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
- if len(text_encoder_2_state_dict) > 0:
- self.load_lora_into_text_encoder(
- text_encoder_2_state_dict,
- network_alphas=None,
- text_encoder=self.text_encoder_2,
- prefix="text_encoder_2",
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=None,
+ text_encoder=self.text_encoder,
+ prefix=self.text_encoder_name,
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=None,
+ text_encoder=self.text_encoder_2,
+ prefix=f"{self.text_encoder_name}_2",
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )

  @classmethod
  def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1142,6 +1344,29 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
  weights.
+ hotswap : (`bool`, *optional*)
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
  """
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
  raise ValueError(
@@ -1156,6 +1381,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  adapter_name=adapter_name,
  _pipeline=_pipeline,
  low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
  )

  @classmethod
@@ -1170,6 +1396,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  adapter_name=None,
  _pipeline=None,
  low_cpu_mem_usage=False,
+ hotswap: bool = False,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1195,6 +1422,29 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
  weights.
+ hotswap : (`bool`, *optional*)
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
  """
  _load_lora_into_text_encoder(
  state_dict=state_dict,
@@ -1206,13 +1456,15 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1206
1456
  adapter_name=adapter_name,
1207
1457
  _pipeline=_pipeline,
1208
1458
  low_cpu_mem_usage=low_cpu_mem_usage,
1459
+ hotswap=hotswap,
1209
1460
  )
1210
1461
 
1211
1462
  @classmethod
1463
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
1212
1464
  def save_lora_weights(
1213
1465
  cls,
1214
1466
  save_directory: Union[str, os.PathLike],
1215
- transformer_lora_layers: Dict[str, torch.nn.Module] = None,
1467
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1216
1468
  text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1217
1469
  text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1218
1470
  is_main_process: bool = True,
@@ -1261,7 +1513,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1261
1513
  if text_encoder_2_lora_layers:
1262
1514
  state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
1263
1515
 
1264
- # Save the model
1265
1516
  cls.write_lora_layers(
1266
1517
  state_dict=state_dict,
1267
1518
  save_directory=save_directory,
@@ -1271,6 +1522,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1271
1522
  safe_serialization=safe_serialization,
1272
1523
  )
1273
1524
 
1525
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
1274
1526
  def fuse_lora(
1275
1527
  self,
1276
1528
  components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
@@ -1311,9 +1563,14 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1311
1563
  ```
1312
1564
  """
1313
1565
  super().fuse_lora(
1314
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
1566
+ components=components,
1567
+ lora_scale=lora_scale,
1568
+ safe_fusing=safe_fusing,
1569
+ adapter_names=adapter_names,
1570
+ **kwargs,
1315
1571
  )
1316
1572
 
1573
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
1317
1574
  def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
1318
1575
  r"""
1319
1576
  Reverses the effect of
@@ -1327,12 +1584,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1327
1584
 
1328
1585
  Args:
1329
1586
  components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
1330
- unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
1587
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
1331
1588
  unfuse_text_encoder (`bool`, defaults to `True`):
1332
1589
  Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
1333
1590
  LoRA parameters then it won't have any effect.
1334
1591
  """
1335
- super().unfuse_lora(components=components)
1592
+ super().unfuse_lora(components=components, **kwargs)
1336
1593
 
1337
1594
 
1338
1595
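The hotswap plumbing threaded through `load_lora_weights`, `load_lora_into_transformer`, and `load_lora_into_text_encoder` above is easiest to read end to end. A minimal sketch of the intended workflow with an SD3 pipeline, where the checkpoint id, LoRA file names, and the rank of 64 are placeholders (64 stands in for the largest rank you plan to load), and `torch.compile` is used as the docstrings suggest:

```py
import torch
from diffusers import StableDiffusion3Pipeline

# placeholder checkpoint; any pipeline exposing SD3LoraLoaderMixin behaves the same way
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# reserve space for the largest LoRA rank *before* compiling, so later swaps fit in place
pipe.enable_lora_hotswap(target_rank=64)
pipe.load_lora_weights("lora_a.safetensors", adapter_name="default_0")
pipe.transformer = torch.compile(pipe.transformer)

image_a = pipe("a watercolor fox", num_inference_steps=28).images[0]

# replace the adapter weights in place; with hotswap=True the compiled graph is reused
pipe.load_lora_weights("lora_b.safetensors", adapter_name="default_0", hotswap=True)
image_b = pipe("a watercolor fox", num_inference_steps=28).images[0]
```

`adapter_name` reuses `default_0`, the name of the adapter loaded first, matching the requirement that the hotswapped name refer to an already loaded adapter; text encoder adapters are not hotswapped.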
  class FluxLoraLoaderMixin(LoraBaseMixin):
@@ -1483,7 +1740,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1483
1740
  return state_dict
1484
1741
 
1485
1742
  def load_lora_weights(
1486
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
1743
+ self,
1744
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
1745
+ adapter_name=None,
1746
+ hotswap: bool = False,
1747
+ **kwargs,
1487
1748
  ):
1488
1749
  """
1489
1750
  Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
@@ -1508,6 +1769,26 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1508
1769
  low_cpu_mem_usage (`bool`, *optional*):
1509
1770
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1510
1771
  weights.
1772
+ hotswap (`bool`, *optional*, defaults to `False`):
1773
+ Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
1774
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
1775
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
1776
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
1777
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
1778
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new
1779
+ adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an
1780
+ additional method before loading the adapter:
1781
+ ```py
1782
+ pipeline = ... # load diffusers pipeline
1783
+ max_rank = ... # the highest rank among all LoRAs that you want to load
1784
+ # call *before* compiling and loading the LoRA adapter
1785
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
1786
+ pipeline.load_lora_weights(file_name)
1787
+ # optionally compile the model now
1788
+ ```
1789
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
1790
+ limitations to this technique, which are documented here:
1791
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
1511
1792
  """
1512
1793
  if not USE_PEFT_BACKEND:
1513
1794
  raise ValueError("PEFT backend is required for this method.")
@@ -1538,18 +1819,23 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1538
1819
  raise ValueError("Invalid LoRA checkpoint.")
1539
1820
 
1540
1821
  transformer_lora_state_dict = {
1541
- k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
1822
+ k: state_dict.get(k)
1823
+ for k in list(state_dict.keys())
1824
+ if k.startswith(f"{self.transformer_name}.") and "lora" in k
1542
1825
  }
1543
1826
  transformer_norm_state_dict = {
1544
1827
  k: state_dict.pop(k)
1545
1828
  for k in list(state_dict.keys())
1546
- if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
1829
+ if k.startswith(f"{self.transformer_name}.")
1830
+ and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
1547
1831
  }
1548
1832
 
1549
1833
  transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
1550
- has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
1551
- transformer, transformer_lora_state_dict, transformer_norm_state_dict
1552
- )
1834
+ has_param_with_expanded_shape = False
1835
+ if len(transformer_lora_state_dict) > 0:
1836
+ has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
1837
+ transformer, transformer_lora_state_dict, transformer_norm_state_dict
1838
+ )
1553
1839
 
1554
1840
  if has_param_with_expanded_shape:
1555
1841
  logger.info(
@@ -1557,19 +1843,22 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1557
1843
  "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
1558
1844
  "To get a comprehensive list of parameter names that were modified, enable debug logging."
1559
1845
  )
1560
- transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
1561
- transformer=transformer, lora_state_dict=transformer_lora_state_dict
1562
- )
1563
-
1564
1846
  if len(transformer_lora_state_dict) > 0:
1565
- self.load_lora_into_transformer(
1566
- transformer_lora_state_dict,
1567
- network_alphas=network_alphas,
1568
- transformer=transformer,
1569
- adapter_name=adapter_name,
1570
- _pipeline=self,
1571
- low_cpu_mem_usage=low_cpu_mem_usage,
1847
+ transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
1848
+ transformer=transformer, lora_state_dict=transformer_lora_state_dict
1572
1849
  )
1850
+ for k in transformer_lora_state_dict:
1851
+ state_dict.update({k: transformer_lora_state_dict[k]})
1852
+
1853
+ self.load_lora_into_transformer(
1854
+ state_dict,
1855
+ network_alphas=network_alphas,
1856
+ transformer=transformer,
1857
+ adapter_name=adapter_name,
1858
+ _pipeline=self,
1859
+ low_cpu_mem_usage=low_cpu_mem_usage,
1860
+ hotswap=hotswap,
1861
+ )
1573
1862
 
1574
1863
  if len(transformer_norm_state_dict) > 0:
1575
1864
  transformer._transformer_norm_layers = self._load_norm_into_transformer(
@@ -1578,22 +1867,28 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1578
1867
  discard_original_layers=False,
1579
1868
  )
1580
1869
 
1581
- text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
1582
- if len(text_encoder_state_dict) > 0:
1583
- self.load_lora_into_text_encoder(
1584
- text_encoder_state_dict,
1585
- network_alphas=network_alphas,
1586
- text_encoder=self.text_encoder,
1587
- prefix="text_encoder",
1588
- lora_scale=self.lora_scale,
1589
- adapter_name=adapter_name,
1590
- _pipeline=self,
1591
- low_cpu_mem_usage=low_cpu_mem_usage,
1592
- )
1870
+ self.load_lora_into_text_encoder(
1871
+ state_dict,
1872
+ network_alphas=network_alphas,
1873
+ text_encoder=self.text_encoder,
1874
+ prefix=self.text_encoder_name,
1875
+ lora_scale=self.lora_scale,
1876
+ adapter_name=adapter_name,
1877
+ _pipeline=self,
1878
+ low_cpu_mem_usage=low_cpu_mem_usage,
1879
+ hotswap=hotswap,
1880
+ )
1593
1881
 
1594
1882
  @classmethod
1595
1883
  def load_lora_into_transformer(
1596
- cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
1884
+ cls,
1885
+ state_dict,
1886
+ network_alphas,
1887
+ transformer,
1888
+ adapter_name=None,
1889
+ _pipeline=None,
1890
+ low_cpu_mem_usage=False,
1891
+ hotswap: bool = False,
1597
1892
  ):
1598
1893
  """
1599
1894
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1615,6 +1910,29 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1615
1910
  low_cpu_mem_usage (`bool`, *optional*):
1616
1911
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1617
1912
  weights.
1913
+ hotswap (`bool`, *optional*, defaults to `False`):
1914
+ Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
1915
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
1916
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
1917
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
1918
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
1919
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
1920
+
1921
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
1922
+ to call an additional method before loading the adapter:
1923
+
1924
+ ```py
1925
+ pipeline = ... # load diffusers pipeline
1926
+ max_rank = ... # the highest rank among all LoRAs that you want to load
1927
+ # call *before* compiling and loading the LoRA adapter
1928
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
1929
+ pipeline.load_lora_weights(file_name)
1930
+ # optionally compile the model now
1931
+ ```
1932
+
1933
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
1934
+ limitations to this technique, which are documented here:
1935
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
1618
1936
  """
1619
1937
  if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
1620
1938
  raise ValueError(
@@ -1622,17 +1940,15 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1622
1940
  )
1623
1941
 
1624
1942
  # Load the layers corresponding to transformer.
1625
- keys = list(state_dict.keys())
1626
- transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
1627
- if transformer_present:
1628
- logger.info(f"Loading {cls.transformer_name}.")
1629
- transformer.load_lora_adapter(
1630
- state_dict,
1631
- network_alphas=network_alphas,
1632
- adapter_name=adapter_name,
1633
- _pipeline=_pipeline,
1634
- low_cpu_mem_usage=low_cpu_mem_usage,
1635
- )
1943
+ logger.info(f"Loading {cls.transformer_name}.")
1944
+ transformer.load_lora_adapter(
1945
+ state_dict,
1946
+ network_alphas=network_alphas,
1947
+ adapter_name=adapter_name,
1948
+ _pipeline=_pipeline,
1949
+ low_cpu_mem_usage=low_cpu_mem_usage,
1950
+ hotswap=hotswap,
1951
+ )
1636
1952
 
1637
1953
  @classmethod
1638
1954
  def _load_norm_into_transformer(
@@ -1700,6 +2016,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1700
2016
  adapter_name=None,
1701
2017
  _pipeline=None,
1702
2018
  low_cpu_mem_usage=False,
2019
+ hotswap: bool = False,
1703
2020
  ):
1704
2021
  """
1705
2022
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1725,6 +2042,29 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1725
2042
  low_cpu_mem_usage (`bool`, *optional*):
1726
2043
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1727
2044
  weights.
2045
+ hotswap (`bool`, *optional*, defaults to `False`):
2046
+ Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
2047
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
2048
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
2049
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
2050
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
2051
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
2052
+
2053
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
2054
+ to call an additional method before loading the adapter:
2055
+
2056
+ ```py
2057
+ pipeline = ... # load diffusers pipeline
2058
+ max_rank = ... # the highest rank among all LoRAs that you want to load
2059
+ # call *before* compiling and loading the LoRA adapter
2060
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
2061
+ pipeline.load_lora_weights(file_name)
2062
+ # optionally compile the model now
2063
+ ```
2064
+
2065
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
2066
+ limitations to this technique, which are documented here:
2067
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
1728
2068
  """
1729
2069
  _load_lora_into_text_encoder(
1730
2070
  state_dict=state_dict,
@@ -1736,6 +2076,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1736
2076
  adapter_name=adapter_name,
1737
2077
  _pipeline=_pipeline,
1738
2078
  low_cpu_mem_usage=low_cpu_mem_usage,
2079
+ hotswap=hotswap,
1739
2080
  )
1740
2081
 
1741
2082
  @classmethod
@@ -1846,7 +2187,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1846
2187
  )
1847
2188
 
1848
2189
  super().fuse_lora(
1849
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
2190
+ components=components,
2191
+ lora_scale=lora_scale,
2192
+ safe_fusing=safe_fusing,
2193
+ adapter_names=adapter_names,
2194
+ **kwargs,
1850
2195
  )
1851
2196
 
1852
2197
  def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
@@ -1867,7 +2212,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1867
2212
  if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
1868
2213
  transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
1869
2214
 
1870
- super().unfuse_lora(components=components)
2215
+ super().unfuse_lora(components=components, **kwargs)
1871
2216
 
1872
2217
  # We override this here account for `_transformer_norm_layers` and `_overwritten_params`.
1873
2218
  def unload_lora_weights(self, reset_to_overwritten_params=False):
@@ -1967,6 +2312,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1967
2312
  overwritten_params = {}
1968
2313
 
1969
2314
  is_peft_loaded = getattr(transformer, "peft_config", None) is not None
2315
+ is_quantized = hasattr(transformer, "hf_quantizer")
1970
2316
  for name, module in transformer.named_modules():
1971
2317
  if isinstance(module, torch.nn.Linear):
1972
2318
  module_weight = module.weight.data
@@ -1991,9 +2337,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1991
2337
  if tuple(module_weight_shape) == (out_features, in_features):
1992
2338
  continue
1993
2339
 
1994
- # TODO (sayakpaul): We still need to consider if the module we're expanding is
1995
- # quantized and handle it accordingly if that is the case.
1996
- module_out_features, module_in_features = module_weight.shape
2340
+ module_out_features, module_in_features = module_weight_shape
1997
2341
  debug_message = ""
1998
2342
  if in_features > module_in_features:
1999
2343
  debug_message += (
@@ -2016,6 +2360,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2016
2360
  parent_module_name, _, current_module_name = name.rpartition(".")
2017
2361
  parent_module = transformer.get_submodule(parent_module_name)
2018
2362
 
2363
+ if is_quantized:
2364
+ module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
2365
+
2366
+ # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
2019
2367
  with torch.device("meta"):
2020
2368
  expanded_module = torch.nn.Linear(
2021
2369
  in_features, out_features, bias=bias, dtype=module_weight.dtype
@@ -2027,7 +2375,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2027
2375
  new_weight = torch.zeros_like(
2028
2376
  expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
2029
2377
  )
2030
- slices = tuple(slice(0, dim) for dim in module_weight.shape)
2378
+ slices = tuple(slice(0, dim) for dim in module_weight_shape)
2031
2379
  new_weight[slices] = module_weight
2032
2380
  tmp_state_dict = {"weight": new_weight}
2033
2381
  if module_bias is not None:
@@ -2116,7 +2464,12 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2116
2464
  base_weight_param_name: str = None,
2117
2465
  ) -> "torch.Size":
2118
2466
  def _get_weight_shape(weight: torch.Tensor):
2119
- return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
2467
+ if weight.__class__.__name__ == "Params4bit":
2468
+ return weight.quant_state.shape
2469
+ elif weight.__class__.__name__ == "GGUFParameter":
2470
+ return weight.quant_shape
2471
+ else:
2472
+ return weight.shape
2120
2473
 
2121
2474
  if base_module is not None:
2122
2475
  return _get_weight_shape(base_module.weight)
@@ -2142,7 +2495,14 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2142
2495
  @classmethod
2143
2496
  # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
2144
2497
  def load_lora_into_transformer(
2145
- cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
2498
+ cls,
2499
+ state_dict,
2500
+ network_alphas,
2501
+ transformer,
2502
+ adapter_name=None,
2503
+ _pipeline=None,
2504
+ low_cpu_mem_usage=False,
2505
+ hotswap: bool = False,
2146
2506
  ):
2147
2507
  """
2148
2508
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -2164,6 +2524,29 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2164
2524
  low_cpu_mem_usage (`bool`, *optional*):
2165
2525
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2166
2526
  weights.
2527
+ hotswap (`bool`, *optional*, defaults to `False`):
2528
+ Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
2529
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
2530
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
2531
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
2532
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
2533
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
2534
+
2535
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
2536
+ to call an additional method before loading the adapter:
2537
+
2538
+ ```py
2539
+ pipeline = ... # load diffusers pipeline
2540
+ max_rank = ... # the highest rank among all LoRAs that you want to load
2541
+ # call *before* compiling and loading the LoRA adapter
2542
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
2543
+ pipeline.load_lora_weights(file_name)
2544
+ # optionally compile the model now
2545
+ ```
2546
+
2547
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
2548
+ limitations to this technique, which are documented here:
2549
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
2167
2550
  """
2168
2551
  if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
2169
2552
  raise ValueError(
@@ -2171,17 +2554,15 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2171
2554
  )
2172
2555
 
2173
2556
  # Load the layers corresponding to transformer.
2174
- keys = list(state_dict.keys())
2175
- transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
2176
- if transformer_present:
2177
- logger.info(f"Loading {cls.transformer_name}.")
2178
- transformer.load_lora_adapter(
2179
- state_dict,
2180
- network_alphas=network_alphas,
2181
- adapter_name=adapter_name,
2182
- _pipeline=_pipeline,
2183
- low_cpu_mem_usage=low_cpu_mem_usage,
2184
- )
2557
+ logger.info(f"Loading {cls.transformer_name}.")
2558
+ transformer.load_lora_adapter(
2559
+ state_dict,
2560
+ network_alphas=network_alphas,
2561
+ adapter_name=adapter_name,
2562
+ _pipeline=_pipeline,
2563
+ low_cpu_mem_usage=low_cpu_mem_usage,
2564
+ hotswap=hotswap,
2565
+ )
2185
2566
 
2186
2567
  @classmethod
2187
2568
  # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2195,6 +2576,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2195
2576
  adapter_name=None,
2196
2577
  _pipeline=None,
2197
2578
  low_cpu_mem_usage=False,
2579
+ hotswap: bool = False,
2198
2580
  ):
2199
2581
  """
2200
2582
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -2220,6 +2602,29 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2220
2602
  low_cpu_mem_usage (`bool`, *optional*):
2221
2603
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2222
2604
  weights.
2605
+ hotswap (`bool`, *optional*, defaults to `False`):
2606
+ Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
2607
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
2608
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
2609
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
2610
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
2611
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
2612
+
2613
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
2614
+ to call an additional method before loading the adapter:
2615
+
2616
+ ```py
2617
+ pipeline = ... # load diffusers pipeline
2618
+ max_rank = ... # the highest rank among all LoRAs that you want to load
2619
+ # call *before* compiling and loading the LoRA adapter
2620
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
2621
+ pipeline.load_lora_weights(file_name)
2622
+ # optionally compile the model now
2623
+ ```
2624
+
2625
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
2626
+ limitations to this technique, which are documented here:
2627
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
2223
2628
  """
2224
2629
  _load_lora_into_text_encoder(
2225
2630
  state_dict=state_dict,
@@ -2231,6 +2636,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2231
2636
  adapter_name=adapter_name,
2232
2637
  _pipeline=_pipeline,
2233
2638
  low_cpu_mem_usage=low_cpu_mem_usage,
2639
+ hotswap=hotswap,
2234
2640
  )
2235
2641
 
2236
2642
  @classmethod
@@ -2447,7 +2853,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2447
2853
  @classmethod
2448
2854
  # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
2449
2855
  def load_lora_into_transformer(
2450
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
2856
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
2451
2857
  ):
2452
2858
  """
2453
2859
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -2465,6 +2871,29 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2465
2871
  low_cpu_mem_usage (`bool`, *optional*):
2466
2872
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2467
2873
  weights.
2874
+ hotswap (`bool`, *optional*, defaults to `False`):
2875
+ Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
2876
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
2877
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
2878
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
2879
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
2880
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
2881
+
2882
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
2883
+ to call an additional method before loading the adapter:
2884
+
2885
+ ```py
2886
+ pipeline = ... # load diffusers pipeline
2887
+ max_rank = ... # the highest rank among all LoRAs that you want to load
2888
+ # call *before* compiling and loading the LoRA adapter
2889
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
2890
+ pipeline.load_lora_weights(file_name)
2891
+ # optionally compile the model now
2892
+ ```
2893
+
2894
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
2895
+ limitations to this technique, which are documented here:
2896
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
2468
2897
  """
2469
2898
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
2470
2899
  raise ValueError(
@@ -2479,6 +2908,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2479
2908
  adapter_name=adapter_name,
2480
2909
  _pipeline=_pipeline,
2481
2910
  low_cpu_mem_usage=low_cpu_mem_usage,
2911
+ hotswap=hotswap,
2482
2912
  )
2483
2913
 
2484
2914
  @classmethod
@@ -2569,7 +2999,11 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2569
2999
  ```
2570
3000
  """
2571
3001
  super().fuse_lora(
2572
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
3002
+ components=components,
3003
+ lora_scale=lora_scale,
3004
+ safe_fusing=safe_fusing,
3005
+ adapter_names=adapter_names,
3006
+ **kwargs,
2573
3007
  )
2574
3008
 
2575
3009
  def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
@@ -2587,7 +3021,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2587
3021
  components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
2588
3022
  unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
2589
3023
  """
2590
- super().unfuse_lora(components=components)
3024
+ super().unfuse_lora(components=components, **kwargs)
2591
3025
 
2592
3026
 
2593
3027
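Beyond hotswapping, the `fuse_lora` / `unfuse_lora` overrides in this file now forward `**kwargs` to `LoraBaseMixin` instead of dropping them. A minimal sketch against CogVideoX, with the checkpoint and LoRA repo ids as placeholders:

```py
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("some-user/cogvideox-lora", adapter_name="motion")

# bake the adapter into the transformer weights at a reduced strength
pipe.fuse_lora(components=["transformer"], lora_scale=0.75, adapter_names=["motion"])
video = pipe("a koala playing the drums", num_inference_steps=50).frames[0]

# restore the original weights; extra keyword arguments now reach the base implementation too
pipe.unfuse_lora(components=["transformer"])
```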
  class Mochi1LoraLoaderMixin(LoraBaseMixin):
@@ -2750,7 +3184,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
2750
3184
  @classmethod
2751
3185
  # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
2752
3186
  def load_lora_into_transformer(
2753
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
3187
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
2754
3188
  ):
2755
3189
  """
2756
3190
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -2768,6 +3202,29 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
2768
3202
  low_cpu_mem_usage (`bool`, *optional*):
2769
3203
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2770
3204
  weights.
3205
+ hotswap (`bool`, *optional*, defaults to `False`):
3206
+ Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
3207
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
3208
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
3209
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
3210
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
3211
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
3212
+
3213
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
3214
+ to call an additional method before loading the adapter:
3215
+
3216
+ ```py
3217
+ pipeline = ... # load diffusers pipeline
3218
+ max_rank = ... # the highest rank among all LoRAs that you want to load
3219
+ # call *before* compiling and loading the LoRA adapter
3220
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
3221
+ pipeline.load_lora_weights(file_name)
3222
+ # optionally compile the model now
3223
+ ```
3224
+
3225
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
3226
+ limitations to this technique, which are documented here:
3227
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
2771
3228
  """
2772
3229
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
2773
3230
  raise ValueError(
@@ -2782,6 +3239,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
2782
3239
  adapter_name=adapter_name,
2783
3240
  _pipeline=_pipeline,
2784
3241
  low_cpu_mem_usage=low_cpu_mem_usage,
3242
+ hotswap=hotswap,
2785
3243
  )
2786
3244
 
2787
3245
  @classmethod
@@ -2832,6 +3290,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
2832
3290
  safe_serialization=safe_serialization,
2833
3291
  )
2834
3292
 
3293
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
2835
3294
  def fuse_lora(
2836
3295
  self,
2837
3296
  components: List[str] = ["transformer"],
@@ -2872,9 +3331,14 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
2872
3331
  ```
2873
3332
  """
2874
3333
  super().fuse_lora(
2875
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
3334
+ components=components,
3335
+ lora_scale=lora_scale,
3336
+ safe_fusing=safe_fusing,
3337
+ adapter_names=adapter_names,
3338
+ **kwargs,
2876
3339
  )
2877
3340
 
3341
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
2878
3342
  def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
2879
3343
  r"""
2880
3344
  Reverses the effect of
@@ -2890,7 +3354,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
2890
3354
  components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
2891
3355
  unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
2892
3356
  """
2893
- super().unfuse_lora(components=components)
3357
+ super().unfuse_lora(components=components, **kwargs)
2894
3358
 
2895
3359
 
2896
3360
  class LTXVideoLoraLoaderMixin(LoraBaseMixin):
@@ -3053,7 +3517,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3053
3517
  @classmethod
3054
3518
  # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
3055
3519
  def load_lora_into_transformer(
3056
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
3520
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
3057
3521
  ):
3058
3522
  """
3059
3523
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -3071,6 +3535,29 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3071
3535
  low_cpu_mem_usage (`bool`, *optional*):
3072
3536
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3073
3537
  weights.
3538
+ hotswap (`bool`, *optional*, defaults to `False`):
3539
+ Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
3540
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
3541
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
3542
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
3543
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
3544
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
3545
+
3546
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
3547
+ to call an additional method before loading the adapter:
3548
+
3549
+ ```py
3550
+ pipeline = ... # load diffusers pipeline
3551
+ max_rank = ... # the highest rank among all LoRAs that you want to load
3552
+ # call *before* compiling and loading the LoRA adapter
3553
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
3554
+ pipeline.load_lora_weights(file_name)
3555
+ # optionally compile the model now
3556
+ ```
3557
+
3558
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
3559
+ limitations to this technique, which are documented here:
3560
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
3074
3561
  """
3075
3562
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3076
3563
  raise ValueError(
@@ -3085,6 +3572,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3085
3572
  adapter_name=adapter_name,
3086
3573
  _pipeline=_pipeline,
3087
3574
  low_cpu_mem_usage=low_cpu_mem_usage,
3575
+ hotswap=hotswap,
3088
3576
  )
3089
3577
 
3090
3578
  @classmethod
@@ -3135,6 +3623,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3135
3623
  safe_serialization=safe_serialization,
3136
3624
  )
3137
3625
 
3626
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
3138
3627
  def fuse_lora(
3139
3628
  self,
3140
3629
  components: List[str] = ["transformer"],
@@ -3175,9 +3664,14 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3175
3664
  ```
3176
3665
  """
3177
3666
  super().fuse_lora(
3178
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
3667
+ components=components,
3668
+ lora_scale=lora_scale,
3669
+ safe_fusing=safe_fusing,
3670
+ adapter_names=adapter_names,
3671
+ **kwargs,
3179
3672
  )
3180
3673
 
3674
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
3181
3675
  def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
3182
3676
  r"""
3183
3677
  Reverses the effect of
@@ -3193,7 +3687,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3193
3687
  components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
3194
3688
  unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
3195
3689
  """
3196
- super().unfuse_lora(components=components)
3690
+ super().unfuse_lora(components=components, **kwargs)
3197
3691
 
3198
3692
 
3199
3693
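All of the loader entry points above gate `low_cpu_mem_usage` behind a `peft` version check before passing it along. A minimal sketch of opting in explicitly for LTX-Video, with placeholder model and LoRA ids:

```py
import torch
from diffusers import LTXPipeline

# placeholder checkpoint and LoRA repo ids
pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights(
    "some-user/ltx-video-lora",
    adapter_name="style",
    low_cpu_mem_usage=True,  # only materializes the LoRA weights; raises if the installed peft is too old
)
```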
  class SanaLoraLoaderMixin(LoraBaseMixin):
@@ -3356,7 +3850,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3356
3850
  @classmethod
3357
3851
  # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
3358
3852
  def load_lora_into_transformer(
3359
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
3853
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
3360
3854
  ):
3361
3855
  """
3362
3856
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -3374,6 +3868,29 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3374
3868
  low_cpu_mem_usage (`bool`, *optional*):
3375
3869
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3376
3870
  weights.
3871
+ hotswap (`bool`, *optional*, defaults to `False`):
3872
+ Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
3873
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
3874
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
3875
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
3876
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
3877
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
3878
+
3879
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
3880
+ to call an additional method before loading the adapter:
3881
+
3882
+ ```py
3883
+ pipeline = ... # load diffusers pipeline
3884
+ max_rank = ... # the highest rank among all LoRAs that you want to load
3885
+ # call *before* compiling and loading the LoRA adapter
3886
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
3887
+ pipeline.load_lora_weights(file_name)
3888
+ # optionally compile the model now
3889
+ ```
3890
+
3891
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
3892
+ limitations to this technique, which are documented here:
3893
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
3377
3894
  """
3378
3895
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3379
3896
  raise ValueError(
@@ -3388,6 +3905,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3388
3905
  adapter_name=adapter_name,
3389
3906
  _pipeline=_pipeline,
3390
3907
  low_cpu_mem_usage=low_cpu_mem_usage,
3908
+ hotswap=hotswap,
3391
3909
  )
3392
3910
 
3393
3911
  @classmethod
@@ -3438,6 +3956,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3438
3956
  safe_serialization=safe_serialization,
3439
3957
  )
3440
3958
 
3959
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
3441
3960
  def fuse_lora(
3442
3961
  self,
3443
3962
  components: List[str] = ["transformer"],
@@ -3478,9 +3997,14 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3478
3997
  ```
3479
3998
  """
3480
3999
  super().fuse_lora(
3481
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
4000
+ components=components,
4001
+ lora_scale=lora_scale,
4002
+ safe_fusing=safe_fusing,
4003
+ adapter_names=adapter_names,
4004
+ **kwargs,
3482
4005
  )
3483
4006
 
4007
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
3484
4008
  def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
3485
4009
  r"""
3486
4010
  Reverses the effect of
@@ -3496,7 +4020,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3496
4020
  components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
3497
4021
  unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
3498
4022
  """
3499
- super().unfuse_lora(components=components)
4023
+ super().unfuse_lora(components=components, **kwargs)
3500
4024
 
3501
4025
 
3502
4026
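The `save_lora_weights` classmethods in these mixins all pack the layer dicts with `pack_weights` and hand them to `write_lora_layers`. A minimal round-trip sketch, assuming `pipe` is any of the pipelines above whose transformer already carries a PEFT LoRA adapter, and assuming `peft.utils.get_peft_model_state_dict` as the helper to extract the layers (directory and file names are placeholders):

```py
from peft.utils import get_peft_model_state_dict

# pull the adapter weights out of a transformer that already has a LoRA attached
transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)

pipe.save_lora_weights(
    save_directory="./my_lora",  # created if it does not exist
    transformer_lora_layers=transformer_lora_layers,
    weight_name="pytorch_lora_weights.safetensors",
    safe_serialization=True,
)

# reload later into a fresh pipeline
pipe.load_lora_weights("./my_lora", weight_name="pytorch_lora_weights.safetensors")
```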
  class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
@@ -3662,7 +4186,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
3662
4186
  @classmethod
3663
4187
  # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
3664
4188
  def load_lora_into_transformer(
3665
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
4189
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
3666
4190
  ):
3667
4191
  """
3668
4192
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -3680,6 +4204,29 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
3680
4204
  low_cpu_mem_usage (`bool`, *optional*):
3681
4205
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3682
4206
  weights.
4207
+ hotswap (`bool`, *optional*, defaults to `False`):
4208
+ Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
4209
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
4210
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
4211
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
4212
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
4213
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
4214
+
4215
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
4216
+ to call an additional method before loading the adapter:
4217
+
4218
+ ```py
4219
+ pipeline = ... # load diffusers pipeline
4220
+ max_rank = ... # the highest rank among all LoRAs that you want to load
4221
+ # call *before* compiling and loading the LoRA adapter
4222
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
4223
+ pipeline.load_lora_weights(file_name)
4224
+ # optionally compile the model now
4225
+ ```
4226
+
4227
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
4228
+ limitations to this technique, which are documented here:
4229
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
3683
4230
  """
3684
4231
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3685
4232
  raise ValueError(
@@ -3694,6 +4241,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
3694
4241
  adapter_name=adapter_name,
3695
4242
  _pipeline=_pipeline,
3696
4243
  low_cpu_mem_usage=low_cpu_mem_usage,
4244
+ hotswap=hotswap,
3697
4245
  )
3698
4246
 
3699
4247
  @classmethod
@@ -3744,6 +4292,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
3744
4292
  safe_serialization=safe_serialization,
3745
4293
  )
3746
4294
 
4295
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
3747
4296
  def fuse_lora(
3748
4297
  self,
3749
4298
  components: List[str] = ["transformer"],
@@ -3784,9 +4333,1048 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
3784
4333
  ```
3785
4334
  """
3786
4335
  super().fuse_lora(
3787
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
4336
+ components=components,
4337
+ lora_scale=lora_scale,
4338
+ safe_fusing=safe_fusing,
4339
+ adapter_names=adapter_names,
4340
+ **kwargs,
4341
+ )
4342
+
4343
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
4344
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
4345
+ r"""
4346
+ Reverses the effect of
4347
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
4348
+
4349
+ <Tip warning={true}>
4350
+
4351
+ This is an experimental API.
4352
+
4353
+ </Tip>
4354
+
4355
+ Args:
4356
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
4357
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
4358
+ """
4359
+ super().unfuse_lora(components=components, **kwargs)
4360
+
4361
+
4362
+ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4363
+ r"""
4364
+ Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`].
4365
+ """
4366
+
4367
+ _lora_loadable_modules = ["transformer"]
4368
+ transformer_name = TRANSFORMER_NAME
4369
+
4370
+ @classmethod
4371
+ @validate_hf_hub_args
4372
+ def lora_state_dict(
4373
+ cls,
4374
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
4375
+ **kwargs,
4376
+ ):
4377
+ r"""
4378
+ Return state dict for lora weights and the network alphas.
4379
+
4380
+ <Tip warning={true}>
4381
+
4382
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
4383
+
4384
+ This function is experimental and might change in the future.
4385
+
4386
+ </Tip>
4387
+
4388
+ Parameters:
4389
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
4390
+ Can be either:
4391
+
4392
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
4393
+ the Hub.
4394
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
4395
+ with [`ModelMixin.save_pretrained`].
4396
+ - A [torch state
4397
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
4398
+
4399
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
4400
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
4401
+ is not used.
4402
+ force_download (`bool`, *optional*, defaults to `False`):
4403
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
4404
+ cached versions if they exist.
4405
+
4406
+ proxies (`Dict[str, str]`, *optional*):
4407
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
4408
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
4409
+ local_files_only (`bool`, *optional*, defaults to `False`):
4410
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
4411
+ won't be downloaded from the Hub.
4412
+ token (`str` or *bool*, *optional*):
4413
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
4414
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
4415
+ revision (`str`, *optional*, defaults to `"main"`):
4416
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
4417
+ allowed by Git.
4418
+ subfolder (`str`, *optional*, defaults to `""`):
4419
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
4420
+
4421
+ """
4422
+ # Load the main state dict first, which has the LoRA layers for either the
4423
+ # transformer or the text encoder, or both.
4424
+ cache_dir = kwargs.pop("cache_dir", None)
4425
+ force_download = kwargs.pop("force_download", False)
4426
+ proxies = kwargs.pop("proxies", None)
4427
+ local_files_only = kwargs.pop("local_files_only", None)
4428
+ token = kwargs.pop("token", None)
4429
+ revision = kwargs.pop("revision", None)
4430
+ subfolder = kwargs.pop("subfolder", None)
4431
+ weight_name = kwargs.pop("weight_name", None)
4432
+ use_safetensors = kwargs.pop("use_safetensors", None)
4433
+
4434
+ allow_pickle = False
4435
+ if use_safetensors is None:
4436
+ use_safetensors = True
4437
+ allow_pickle = True
4438
+
4439
+ user_agent = {
4440
+ "file_type": "attn_procs_weights",
4441
+ "framework": "pytorch",
4442
+ }
4443
+
4444
+ state_dict = _fetch_state_dict(
4445
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
4446
+ weight_name=weight_name,
4447
+ use_safetensors=use_safetensors,
4448
+ local_files_only=local_files_only,
4449
+ cache_dir=cache_dir,
4450
+ force_download=force_download,
4451
+ proxies=proxies,
4452
+ token=token,
4453
+ revision=revision,
4454
+ subfolder=subfolder,
4455
+ user_agent=user_agent,
4456
+ allow_pickle=allow_pickle,
4457
+ )
4458
+
4459
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
4460
+ if is_dora_scale_present:
4461
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
4462
+ logger.warning(warn_msg)
4463
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
4464
+
4465
+ # conversion.
4466
+ non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
4467
+ if non_diffusers:
4468
+ state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
4469
+
4470
+ return state_dict
4471
+
4472
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
4473
+ def load_lora_weights(
4474
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
4475
+ ):
4476
+ """
4477
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
4478
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
4479
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
4480
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
4481
+ dict is loaded into `self.transformer`.
4482
+
4483
+ Parameters:
4484
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
4485
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4486
+ adapter_name (`str`, *optional*):
4487
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4488
+ `default_{i}` where i is the total number of adapters being loaded.
4489
+ low_cpu_mem_usage (`bool`, *optional*):
4490
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4491
+ weights.
4492
+ kwargs (`dict`, *optional*):
4493
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4494
+ """
4495
+ if not USE_PEFT_BACKEND:
4496
+ raise ValueError("PEFT backend is required for this method.")
4497
+
4498
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
4499
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4500
+ raise ValueError(
4501
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4502
+ )
4503
+
4504
+ # if a dict is passed, copy it instead of modifying it inplace
4505
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
4506
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
4507
+
4508
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
4509
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4510
+
4511
+ is_correct_format = all("lora" in key for key in state_dict.keys())
4512
+ if not is_correct_format:
4513
+ raise ValueError("Invalid LoRA checkpoint.")
4514
+
4515
+ self.load_lora_into_transformer(
4516
+ state_dict,
4517
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
4518
+ adapter_name=adapter_name,
4519
+ _pipeline=self,
4520
+ low_cpu_mem_usage=low_cpu_mem_usage,
4521
+ )
4522
+
4523
+ @classmethod
4524
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
4525
+ def load_lora_into_transformer(
4526
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
4527
+ ):
4528
+ """
4529
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
4530
+
4531
+ Parameters:
4532
+ state_dict (`dict`):
4533
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
4534
+ into the transformer or prefixed with an additional `transformer` which can be used to distinguish between text
4535
+ encoder lora layers.
4536
+ transformer (`Lumina2Transformer2DModel`):
4537
+ The Transformer model to load the LoRA layers into.
4538
+ adapter_name (`str`, *optional*):
4539
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4540
+ `default_{i}` where i is the total number of adapters being loaded.
4541
+ low_cpu_mem_usage (`bool`, *optional*):
4542
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4543
+ weights.
4544
+ hotswap (`bool`, *optional*, defaults to `False`):
4545
+ Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
4546
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
4547
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
4548
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
4549
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
4550
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
4551
+
4552
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
4553
+ to call an additional method before loading the adapter:
4554
+
4555
+ ```py
4556
+ pipeline = ... # load diffusers pipeline
4557
+ max_rank = ... # the highest rank among all LoRAs that you want to load
4558
+ # call *before* compiling and loading the LoRA adapter
4559
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
4560
+ pipeline.load_lora_weights(file_name)
4561
+ # optionally compile the model now
4562
+ ```
4563
+
4564
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
4565
+ limitations to this technique, which are documented here:
4566
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
4567
+ """
4568
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4569
+ raise ValueError(
4570
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4571
+ )
4572
+
4573
+ # Load the layers corresponding to transformer.
4574
+ logger.info(f"Loading {cls.transformer_name}.")
4575
+ transformer.load_lora_adapter(
4576
+ state_dict,
4577
+ network_alphas=None,
4578
+ adapter_name=adapter_name,
4579
+ _pipeline=_pipeline,
4580
+ low_cpu_mem_usage=low_cpu_mem_usage,
4581
+ hotswap=hotswap,
4582
+ )
4583
+
4584
+ @classmethod
4585
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
4586
+ def save_lora_weights(
4587
+ cls,
4588
+ save_directory: Union[str, os.PathLike],
4589
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
4590
+ is_main_process: bool = True,
4591
+ weight_name: str = None,
4592
+ save_function: Callable = None,
4593
+ safe_serialization: bool = True,
4594
+ ):
4595
+ r"""
4596
+ Save the LoRA parameters corresponding to the transformer.
4597
+
4598
+ Arguments:
4599
+ save_directory (`str` or `os.PathLike`):
4600
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
4601
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
4602
+ State dict of the LoRA layers corresponding to the `transformer`.
4603
+ is_main_process (`bool`, *optional*, defaults to `True`):
4604
+ Whether the process calling this is the main process or not. Useful during distributed training when you
4605
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
4606
+ process to avoid race conditions.
4607
+ save_function (`Callable`):
4608
+ The function to use to save the state dictionary. Useful during distributed training when you need to
4609
+ replace `torch.save` with another method. Can be configured with the environment variable
4610
+ `DIFFUSERS_SAVE_MODE`.
4611
+ safe_serialization (`bool`, *optional*, defaults to `True`):
4612
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
4613
+ """
4614
+ state_dict = {}
4615
+
4616
+ if not transformer_lora_layers:
4617
+ raise ValueError("You must pass `transformer_lora_layers`.")
4618
+
4619
+ if transformer_lora_layers:
4620
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
4621
+
4622
+ # Save the model
4623
+ cls.write_lora_layers(
4624
+ state_dict=state_dict,
4625
+ save_directory=save_directory,
4626
+ is_main_process=is_main_process,
4627
+ weight_name=weight_name,
4628
+ save_function=save_function,
4629
+ safe_serialization=safe_serialization,
4630
+ )
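The counterpart to loading is exporting trained LoRA layers with this method. A hedged sketch follows, assuming the transformer was adapted with peft and that `get_peft_model_state_dict` is available from the installed peft version; the paths are placeholders.

```py
# Hedged sketch: pull the LoRA-only weights out of a peft-adapted transformer and
# write them with save_lora_weights (a classmethod, so calling it on the instance also works).
from peft import get_peft_model_state_dict

transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)
pipe.save_lora_weights(
    save_directory="./my-lumina2-lora",                    # placeholder output directory
    transformer_lora_layers=transformer_lora_layers,
    weight_name="pytorch_lora_weights.safetensors",
)
```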
4631
+
4632
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
4633
+ def fuse_lora(
4634
+ self,
4635
+ components: List[str] = ["transformer"],
4636
+ lora_scale: float = 1.0,
4637
+ safe_fusing: bool = False,
4638
+ adapter_names: Optional[List[str]] = None,
4639
+ **kwargs,
4640
+ ):
4641
+ r"""
4642
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
4643
+
4644
+ <Tip warning={true}>
4645
+
4646
+ This is an experimental API.
4647
+
4648
+ </Tip>
4649
+
4650
+ Args:
4651
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
4652
+ lora_scale (`float`, defaults to 1.0):
4653
+ Controls how much to influence the outputs with the LoRA parameters.
4654
+ safe_fusing (`bool`, defaults to `False`):
4655
+ Whether to check fused weights for NaN values before fusing and, if NaN values are found, to skip fusing those weights.
4656
+ adapter_names (`List[str]`, *optional*):
4657
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
4658
+
4659
+ Example:
4660
+
4661
+ ```py
4662
+ from diffusers import DiffusionPipeline
4663
+ import torch
4664
+
4665
+ pipeline = DiffusionPipeline.from_pretrained(
4666
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
4667
+ ).to("cuda")
4668
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
4669
+ pipeline.fuse_lora(lora_scale=0.7)
4670
+ ```
4671
+ """
4672
+ super().fuse_lora(
4673
+ components=components,
4674
+ lora_scale=lora_scale,
4675
+ safe_fusing=safe_fusing,
4676
+ adapter_names=adapter_names,
4677
+ **kwargs,
4678
+ )
4679
+
4680
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
4681
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
4682
+ r"""
4683
+ Reverses the effect of
4684
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
4685
+
4686
+ <Tip warning={true}>
4687
+
4688
+ This is an experimental API.
4689
+
4690
+ </Tip>
4691
+
4692
+ Args:
4693
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
4694
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
4695
+ """
4696
+ super().unfuse_lora(components=components, **kwargs)
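A short fuse/unfuse round trip, sketched under the assumption that an adapter named `example` has already been loaded on the pipeline; the adapter name is a placeholder.

```py
# Hedged sketch of fusing and later unfusing a previously loaded adapter.
pipe.fuse_lora(components=["transformer"], lora_scale=0.8, adapter_names=["example"])
# ... run inference with the LoRA baked into the transformer weights ...
pipe.unfuse_lora(components=["transformer"])  # restore the original transformer weights
```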
4697
+
4698
+
4699
+ class WanLoraLoaderMixin(LoraBaseMixin):
4700
+ r"""
4701
+ Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`].
4702
+ """
4703
+
4704
+ _lora_loadable_modules = ["transformer"]
4705
+ transformer_name = TRANSFORMER_NAME
4706
+
4707
+ @classmethod
4708
+ @validate_hf_hub_args
4709
+ def lora_state_dict(
4710
+ cls,
4711
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
4712
+ **kwargs,
4713
+ ):
4714
+ r"""
4715
+ Return the state dict for the LoRA weights.
4716
+
4717
+ <Tip warning={true}>
4718
+
4719
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
4720
+
4721
+ This function is experimental and might change in the future.
4722
+
4723
+ </Tip>
4724
+
4725
+ Parameters:
4726
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
4727
+ Can be either:
4728
+
4729
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
4730
+ the Hub.
4731
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
4732
+ with [`ModelMixin.save_pretrained`].
4733
+ - A [torch state
4734
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
4735
+
4736
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
4737
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
4738
+ is not used.
4739
+ force_download (`bool`, *optional*, defaults to `False`):
4740
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
4741
+ cached versions if they exist.
4742
+
4743
+ proxies (`Dict[str, str]`, *optional*):
4744
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
4745
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
4746
+ local_files_only (`bool`, *optional*, defaults to `False`):
4747
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
4748
+ won't be downloaded from the Hub.
4749
+ token (`str` or *bool*, *optional*):
4750
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
4751
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
4752
+ revision (`str`, *optional*, defaults to `"main"`):
4753
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
4754
+ allowed by Git.
4755
+ subfolder (`str`, *optional*, defaults to `""`):
4756
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
4757
+
4758
+ """
4759
+ # Load the main state dict first which has the LoRA layers for either of
4760
+ # transformer and text encoder or both.
4761
+ cache_dir = kwargs.pop("cache_dir", None)
4762
+ force_download = kwargs.pop("force_download", False)
4763
+ proxies = kwargs.pop("proxies", None)
4764
+ local_files_only = kwargs.pop("local_files_only", None)
4765
+ token = kwargs.pop("token", None)
4766
+ revision = kwargs.pop("revision", None)
4767
+ subfolder = kwargs.pop("subfolder", None)
4768
+ weight_name = kwargs.pop("weight_name", None)
4769
+ use_safetensors = kwargs.pop("use_safetensors", None)
4770
+
4771
+ allow_pickle = False
4772
+ if use_safetensors is None:
4773
+ use_safetensors = True
4774
+ allow_pickle = True
4775
+
4776
+ user_agent = {
4777
+ "file_type": "attn_procs_weights",
4778
+ "framework": "pytorch",
4779
+ }
4780
+
4781
+ state_dict = _fetch_state_dict(
4782
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
4783
+ weight_name=weight_name,
4784
+ use_safetensors=use_safetensors,
4785
+ local_files_only=local_files_only,
4786
+ cache_dir=cache_dir,
4787
+ force_download=force_download,
4788
+ proxies=proxies,
4789
+ token=token,
4790
+ revision=revision,
4791
+ subfolder=subfolder,
4792
+ user_agent=user_agent,
4793
+ allow_pickle=allow_pickle,
4794
+ )
4795
+ if any(k.startswith("diffusion_model.") for k in state_dict):
4796
+ state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
4797
+
4798
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
4799
+ if is_dora_scale_present:
4800
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
4801
+ logger.warning(warn_msg)
4802
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
4803
+
4804
+ return state_dict
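Because `lora_state_dict` is a classmethod, it can also be used on its own to inspect a checkpoint before loading it, assuming `WanPipeline` exposes this mixin; the repository id below is a hypothetical placeholder.

```py
# Hedged sketch: fetch and inspect a Wan LoRA state dict without loading it into a model.
from diffusers import WanPipeline

state_dict = WanPipeline.lora_state_dict("some-user/wan-t2v-style-lora")  # hypothetical repo
print(len(state_dict))
print(next(iter(state_dict)))  # keys are expected to contain "lora"; "diffusion_model." keys are converted above
```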
4805
+
4806
+ @classmethod
4807
+ def _maybe_expand_t2v_lora_for_i2v(
4808
+ cls,
4809
+ transformer: torch.nn.Module,
4810
+ state_dict,
4811
+ ):
4812
+ if transformer.config.image_dim is None:
4813
+ return state_dict
4814
+
4815
+ if any(k.startswith("transformer.blocks.") for k in state_dict):
4816
+ num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
4817
+ is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
4818
+
4819
+ if is_i2v_lora:
4820
+ return state_dict
4821
+
4822
+ for i in range(num_blocks):
4823
+ for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
4824
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
4825
+ state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
4826
+ )
4827
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
4828
+ state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
4829
+ )
4830
+
4831
+ return state_dict
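The zero padding above keeps the expanded adapter inert: with standard LoRA math the weight update is `lora_B @ lora_A`, so zero-filled matrices contribute nothing to the `k_img`/`v_img` projections. A small standalone check of that reasoning (not library code):

```py
# Standalone check: a zero-filled LoRA pair produces a zero weight update, so the
# synthesized add_k_proj/add_v_proj entries leave the I2V projections unchanged.
import torch

rank, dim = 4, 16
lora_A = torch.zeros(rank, dim)   # mirrors torch.zeros_like(...lora_A.weight) above
lora_B = torch.zeros(dim, rank)   # mirrors torch.zeros_like(...lora_B.weight) above
delta_w = lora_B @ lora_A         # the LoRA weight update
assert torch.count_nonzero(delta_w) == 0
```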
4832
+
4833
+ def load_lora_weights(
4834
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
4835
+ ):
4836
+ """
4837
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`.
4838
+ All kwargs are forwarded to `self.lora_state_dict`. See
4839
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
4840
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
4841
+ dict is loaded into `self.transformer`.
4842
+
4843
+ Parameters:
4844
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
4845
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4846
+ adapter_name (`str`, *optional*):
4847
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4848
+ `default_{i}` where i is the total number of adapters being loaded.
4849
+ low_cpu_mem_usage (`bool`, *optional*):
4850
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4851
+ weights.
4852
+ kwargs (`dict`, *optional*):
4853
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4854
+ """
4855
+ if not USE_PEFT_BACKEND:
4856
+ raise ValueError("PEFT backend is required for this method.")
4857
+
4858
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
4859
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4860
+ raise ValueError(
4861
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4862
+ )
4863
+
4864
+ # if a dict is passed, copy it instead of modifying it inplace
4865
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
4866
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
4867
+
4868
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
4869
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4870
+ # convert a T2V LoRA into an I2V LoRA (when loading into a Wan I2V pipeline) by adding zeros for the additional (missing) _img layers
4871
+ state_dict = self._maybe_expand_t2v_lora_for_i2v(
4872
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
4873
+ state_dict=state_dict,
4874
+ )
4875
+ is_correct_format = all("lora" in key for key in state_dict.keys())
4876
+ if not is_correct_format:
4877
+ raise ValueError("Invalid LoRA checkpoint.")
4878
+
4879
+ self.load_lora_into_transformer(
4880
+ state_dict,
4881
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
4882
+ adapter_name=adapter_name,
4883
+ _pipeline=self,
4884
+ low_cpu_mem_usage=low_cpu_mem_usage,
4885
+ )
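A hedged end-to-end sketch of the case the expansion helper above exists for: loading a text-to-video LoRA into the image-to-video pipeline. The checkpoint ids are assumptions for illustration and the LoRA repository is hypothetical.

```py
# Hedged sketch: a T2V-trained LoRA loads into the I2V pipeline because the missing
# k_img/v_img projections are padded with zeros by _maybe_expand_t2v_lora_for_i2v.
import torch
from diffusers import WanImageToVideoPipeline

pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("some-user/wan-t2v-style-lora", adapter_name="style")  # hypothetical repo
```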
4886
+
4887
+ @classmethod
4888
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
4889
+ def load_lora_into_transformer(
4890
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
4891
+ ):
4892
+ """
4893
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
4894
+
4895
+ Parameters:
4896
+ state_dict (`dict`):
4897
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
4898
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish between text
4899
+ encoder lora layers.
4900
+ transformer (`WanTransformer3DModel`):
4901
+ The Transformer model to load the LoRA layers into.
4902
+ adapter_name (`str`, *optional*):
4903
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4904
+ `default_{i}` where i is the total number of adapters being loaded.
4905
+ low_cpu_mem_usage (`bool`, *optional*):
4906
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4907
+ weights.
4908
+ hotswap (`bool`, *optional*):
4909
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
4910
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
4911
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
4912
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
4913
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
4914
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
4915
+
4916
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
4917
+ to call an additional method before loading the adapter:
4918
+
4919
+ ```py
4920
+ pipeline = ... # load diffusers pipeline
4921
+ max_rank = ... # the highest rank among all LoRAs that you want to load
4922
+ # call *before* compiling and loading the LoRA adapter
4923
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
4924
+ pipeline.load_lora_weights(file_name)
4925
+ # optionally compile the model now
4926
+ ```
4927
+
4928
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
4929
+ limitations to this technique, which are documented here:
4930
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
4931
+ """
4932
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4933
+ raise ValueError(
4934
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4935
+ )
4936
+
4937
+ # Load the layers corresponding to transformer.
4938
+ logger.info(f"Loading {cls.transformer_name}.")
4939
+ transformer.load_lora_adapter(
4940
+ state_dict,
4941
+ network_alphas=None,
4942
+ adapter_name=adapter_name,
4943
+ _pipeline=_pipeline,
4944
+ low_cpu_mem_usage=low_cpu_mem_usage,
4945
+ hotswap=hotswap,
4946
+ )
4947
+
4948
+ @classmethod
4949
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
4950
+ def save_lora_weights(
4951
+ cls,
4952
+ save_directory: Union[str, os.PathLike],
4953
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
4954
+ is_main_process: bool = True,
4955
+ weight_name: str = None,
4956
+ save_function: Callable = None,
4957
+ safe_serialization: bool = True,
4958
+ ):
4959
+ r"""
4960
+ Save the LoRA parameters corresponding to the transformer.
4961
+
4962
+ Arguments:
4963
+ save_directory (`str` or `os.PathLike`):
4964
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
4965
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
4966
+ State dict of the LoRA layers corresponding to the `transformer`.
4967
+ is_main_process (`bool`, *optional*, defaults to `True`):
4968
+ Whether the process calling this is the main process or not. Useful during distributed training when you
4969
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
4970
+ process to avoid race conditions.
4971
+ save_function (`Callable`):
4972
+ The function to use to save the state dictionary. Useful during distributed training when you need to
4973
+ replace `torch.save` with another method. Can be configured with the environment variable
4974
+ `DIFFUSERS_SAVE_MODE`.
4975
+ safe_serialization (`bool`, *optional*, defaults to `True`):
4976
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
4977
+ """
4978
+ state_dict = {}
4979
+
4980
+ if not transformer_lora_layers:
4981
+ raise ValueError("You must pass `transformer_lora_layers`.")
4982
+
4983
+ if transformer_lora_layers:
4984
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
4985
+
4986
+ # Save the model
4987
+ cls.write_lora_layers(
4988
+ state_dict=state_dict,
4989
+ save_directory=save_directory,
4990
+ is_main_process=is_main_process,
4991
+ weight_name=weight_name,
4992
+ save_function=save_function,
4993
+ safe_serialization=safe_serialization,
4994
+ )
4995
+
4996
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
4997
+ def fuse_lora(
4998
+ self,
4999
+ components: List[str] = ["transformer"],
5000
+ lora_scale: float = 1.0,
5001
+ safe_fusing: bool = False,
5002
+ adapter_names: Optional[List[str]] = None,
5003
+ **kwargs,
5004
+ ):
5005
+ r"""
5006
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
5007
+
5008
+ <Tip warning={true}>
5009
+
5010
+ This is an experimental API.
5011
+
5012
+ </Tip>
5013
+
5014
+ Args:
5015
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
5016
+ lora_scale (`float`, defaults to 1.0):
5017
+ Controls how much to influence the outputs with the LoRA parameters.
5018
+ safe_fusing (`bool`, defaults to `False`):
5019
+ Whether to check fused weights for NaN values before fusing and, if NaN values are found, to skip fusing those weights.
5020
+ adapter_names (`List[str]`, *optional*):
5021
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
5022
+
5023
+ Example:
5024
+
5025
+ ```py
5026
+ from diffusers import DiffusionPipeline
5027
+ import torch
5028
+
5029
+ pipeline = DiffusionPipeline.from_pretrained(
5030
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
5031
+ ).to("cuda")
5032
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
5033
+ pipeline.fuse_lora(lora_scale=0.7)
5034
+ ```
5035
+ """
5036
+ super().fuse_lora(
5037
+ components=components,
5038
+ lora_scale=lora_scale,
5039
+ safe_fusing=safe_fusing,
5040
+ adapter_names=adapter_names,
5041
+ **kwargs,
5042
+ )
5043
+
5044
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
5045
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
5046
+ r"""
5047
+ Reverses the effect of
5048
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
5049
+
5050
+ <Tip warning={true}>
5051
+
5052
+ This is an experimental API.
5053
+
5054
+ </Tip>
5055
+
5056
+ Args:
5057
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
5058
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
5059
+ """
5060
+ super().unfuse_lora(components=components, **kwargs)
5061
+
5062
+
5063
+ class CogView4LoraLoaderMixin(LoraBaseMixin):
5064
+ r"""
5065
+ Load LoRA layers into [`CogView4Transformer2DModel`]. Specific to [`CogView4Pipeline`].
5066
+ """
5067
+
5068
+ _lora_loadable_modules = ["transformer"]
5069
+ transformer_name = TRANSFORMER_NAME
5070
+
5071
+ @classmethod
5072
+ @validate_hf_hub_args
5073
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
5074
+ def lora_state_dict(
5075
+ cls,
5076
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
5077
+ **kwargs,
5078
+ ):
5079
+ r"""
5080
+ Return the state dict for the LoRA weights.
5081
+
5082
+ <Tip warning={true}>
5083
+
5084
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
5085
+
5086
+ This function is experimental and might change in the future.
5087
+
5088
+ </Tip>
5089
+
5090
+ Parameters:
5091
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
5092
+ Can be either:
5093
+
5094
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
5095
+ the Hub.
5096
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
5097
+ with [`ModelMixin.save_pretrained`].
5098
+ - A [torch state
5099
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
5100
+
5101
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
5102
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
5103
+ is not used.
5104
+ force_download (`bool`, *optional*, defaults to `False`):
5105
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
5106
+ cached versions if they exist.
5107
+
5108
+ proxies (`Dict[str, str]`, *optional*):
5109
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
5110
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
5111
+ local_files_only (`bool`, *optional*, defaults to `False`):
5112
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
5113
+ won't be downloaded from the Hub.
5114
+ token (`str` or *bool*, *optional*):
5115
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
5116
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
5117
+ revision (`str`, *optional*, defaults to `"main"`):
5118
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
5119
+ allowed by Git.
5120
+ subfolder (`str`, *optional*, defaults to `""`):
5121
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
5122
+
5123
+ """
5124
+ # Load the main state dict first which has the LoRA layers for either of
5125
+ # transformer and text encoder or both.
5126
+ cache_dir = kwargs.pop("cache_dir", None)
5127
+ force_download = kwargs.pop("force_download", False)
5128
+ proxies = kwargs.pop("proxies", None)
5129
+ local_files_only = kwargs.pop("local_files_only", None)
5130
+ token = kwargs.pop("token", None)
5131
+ revision = kwargs.pop("revision", None)
5132
+ subfolder = kwargs.pop("subfolder", None)
5133
+ weight_name = kwargs.pop("weight_name", None)
5134
+ use_safetensors = kwargs.pop("use_safetensors", None)
5135
+
5136
+ allow_pickle = False
5137
+ if use_safetensors is None:
5138
+ use_safetensors = True
5139
+ allow_pickle = True
5140
+
5141
+ user_agent = {
5142
+ "file_type": "attn_procs_weights",
5143
+ "framework": "pytorch",
5144
+ }
5145
+
5146
+ state_dict = _fetch_state_dict(
5147
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
5148
+ weight_name=weight_name,
5149
+ use_safetensors=use_safetensors,
5150
+ local_files_only=local_files_only,
5151
+ cache_dir=cache_dir,
5152
+ force_download=force_download,
5153
+ proxies=proxies,
5154
+ token=token,
5155
+ revision=revision,
5156
+ subfolder=subfolder,
5157
+ user_agent=user_agent,
5158
+ allow_pickle=allow_pickle,
5159
+ )
5160
+
5161
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
5162
+ if is_dora_scale_present:
5163
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
5164
+ logger.warning(warn_msg)
5165
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
5166
+
5167
+ return state_dict
5168
+
5169
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
5170
+ def load_lora_weights(
5171
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
5172
+ ):
5173
+ """
5174
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`.
5175
+ All kwargs are forwarded to `self.lora_state_dict`. See
5176
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
5177
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
5178
+ dict is loaded into `self.transformer`.
5179
+
5180
+ Parameters:
5181
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
5182
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
5183
+ adapter_name (`str`, *optional*):
5184
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
5185
+ `default_{i}` where i is the total number of adapters being loaded.
5186
+ low_cpu_mem_usage (`bool`, *optional*):
5187
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
5188
+ weights.
5189
+ kwargs (`dict`, *optional*):
5190
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
5191
+ """
5192
+ if not USE_PEFT_BACKEND:
5193
+ raise ValueError("PEFT backend is required for this method.")
5194
+
5195
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
5196
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
5197
+ raise ValueError(
5198
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
5199
+ )
5200
+
5201
+ # if a dict is passed, copy it instead of modifying it inplace
5202
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
5203
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
5204
+
5205
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
5206
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
5207
+
5208
+ is_correct_format = all("lora" in key for key in state_dict.keys())
5209
+ if not is_correct_format:
5210
+ raise ValueError("Invalid LoRA checkpoint.")
5211
+
5212
+ self.load_lora_into_transformer(
5213
+ state_dict,
5214
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
5215
+ adapter_name=adapter_name,
5216
+ _pipeline=self,
5217
+ low_cpu_mem_usage=low_cpu_mem_usage,
5218
+ )
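The same pattern applies to the CogView4 mixin; a minimal hedged sketch with an assumed base checkpoint id and a hypothetical LoRA repository:

```py
# Hedged sketch of loading a LoRA through the CogView4 mixin.
import torch
from diffusers import CogView4Pipeline

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("some-user/cogview4-style-lora", adapter_name="style")  # hypothetical repo
image = pipe("an isometric voxel castle at dusk", num_inference_steps=50).images[0]
```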
5219
+
5220
+ @classmethod
5221
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
5222
+ def load_lora_into_transformer(
5223
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
5224
+ ):
5225
+ """
5226
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
5227
+
5228
+ Parameters:
5229
+ state_dict (`dict`):
5230
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
5231
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish between text
5232
+ encoder lora layers.
5233
+ transformer (`CogView4Transformer2DModel`):
5234
+ The Transformer model to load the LoRA layers into.
5235
+ adapter_name (`str`, *optional*):
5236
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
5237
+ `default_{i}` where i is the total number of adapters being loaded.
5238
+ low_cpu_mem_usage (`bool`, *optional*):
5239
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
5240
+ weights.
5241
+ hotswap (`bool`, *optional*):
5242
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
5243
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
5244
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
5245
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
5246
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
5247
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
5248
+
5249
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
5250
+ to call an additional method before loading the adapter:
5251
+
5252
+ ```py
5253
+ pipeline = ... # load diffusers pipeline
5254
+ max_rank = ... # the highest rank among all LoRAs that you want to load
5255
+ # call *before* compiling and loading the LoRA adapter
5256
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
5257
+ pipeline.load_lora_weights(file_name)
5258
+ # optionally compile the model now
5259
+ ```
5260
+
5261
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
5262
+ limitations to this technique, which are documented here:
5263
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
5264
+ """
5265
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
5266
+ raise ValueError(
5267
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
5268
+ )
5269
+
5270
+ # Load the layers corresponding to transformer.
5271
+ logger.info(f"Loading {cls.transformer_name}.")
5272
+ transformer.load_lora_adapter(
5273
+ state_dict,
5274
+ network_alphas=None,
5275
+ adapter_name=adapter_name,
5276
+ _pipeline=_pipeline,
5277
+ low_cpu_mem_usage=low_cpu_mem_usage,
5278
+ hotswap=hotswap,
5279
+ )
5280
+
5281
+ @classmethod
5282
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
5283
+ def save_lora_weights(
5284
+ cls,
5285
+ save_directory: Union[str, os.PathLike],
5286
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
5287
+ is_main_process: bool = True,
5288
+ weight_name: str = None,
5289
+ save_function: Callable = None,
5290
+ safe_serialization: bool = True,
5291
+ ):
5292
+ r"""
5293
+ Save the LoRA parameters corresponding to the transformer.
5294
+
5295
+ Arguments:
5296
+ save_directory (`str` or `os.PathLike`):
5297
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
5298
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
5299
+ State dict of the LoRA layers corresponding to the `transformer`.
5300
+ is_main_process (`bool`, *optional*, defaults to `True`):
5301
+ Whether the process calling this is the main process or not. Useful during distributed training when you
5302
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
5303
+ process to avoid race conditions.
5304
+ save_function (`Callable`):
5305
+ The function to use to save the state dictionary. Useful during distributed training when you need to
5306
+ replace `torch.save` with another method. Can be configured with the environment variable
5307
+ `DIFFUSERS_SAVE_MODE`.
5308
+ safe_serialization (`bool`, *optional*, defaults to `True`):
5309
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
5310
+ """
5311
+ state_dict = {}
5312
+
5313
+ if not transformer_lora_layers:
5314
+ raise ValueError("You must pass `transformer_lora_layers`.")
5315
+
5316
+ if transformer_lora_layers:
5317
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
5318
+
5319
+ # Save the model
5320
+ cls.write_lora_layers(
5321
+ state_dict=state_dict,
5322
+ save_directory=save_directory,
5323
+ is_main_process=is_main_process,
5324
+ weight_name=weight_name,
5325
+ save_function=save_function,
5326
+ safe_serialization=safe_serialization,
5327
+ )
5328
+
5329
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
5330
+ def fuse_lora(
5331
+ self,
5332
+ components: List[str] = ["transformer"],
5333
+ lora_scale: float = 1.0,
5334
+ safe_fusing: bool = False,
5335
+ adapter_names: Optional[List[str]] = None,
5336
+ **kwargs,
5337
+ ):
5338
+ r"""
5339
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
5340
+
5341
+ <Tip warning={true}>
5342
+
5343
+ This is an experimental API.
5344
+
5345
+ </Tip>
5346
+
5347
+ Args:
5348
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
5349
+ lora_scale (`float`, defaults to 1.0):
5350
+ Controls how much to influence the outputs with the LoRA parameters.
5351
+ safe_fusing (`bool`, defaults to `False`):
5352
+ Whether to check fused weights for NaN values before fusing and, if NaN values are found, to skip fusing those weights.
5353
+ adapter_names (`List[str]`, *optional*):
5354
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
5355
+
5356
+ Example:
5357
+
5358
+ ```py
5359
+ from diffusers import DiffusionPipeline
5360
+ import torch
5361
+
5362
+ pipeline = DiffusionPipeline.from_pretrained(
5363
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
5364
+ ).to("cuda")
5365
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
5366
+ pipeline.fuse_lora(lora_scale=0.7)
5367
+ ```
5368
+ """
5369
+ super().fuse_lora(
5370
+ components=components,
5371
+ lora_scale=lora_scale,
5372
+ safe_fusing=safe_fusing,
5373
+ adapter_names=adapter_names,
5374
+ **kwargs,
3788
5375
  )
3789
5376
 
5377
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
3790
5378
  def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
3791
5379
  r"""
3792
5380
  Reverses the effect of
@@ -3802,7 +5390,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
3802
5390
  components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
3803
5391
  unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
3804
5392
  """
3805
- super().unfuse_lora(components=components)
5393
+ super().unfuse_lora(components=components, **kwargs)
3806
5394
 
3807
5395
 
3808
5396
  class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):