diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
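
The largest single change below is in `diffusers/loaders/lora_pipeline.py`, which threads a new `hotswap` argument through the LoRA loading methods. A minimal usage sketch of that workflow, assuming a Stable Diffusion checkpoint and two local LoRA files (the model id, rank, and file names here are placeholders, not taken from the diff):

```py
# Sketch of the hotswap workflow added in 0.33.0; names and values are illustrative.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "some-org/some-sd-checkpoint", torch_dtype=torch.float16
).to("cuda")

# If the adapters differ in rank or LoRA alpha, reserve the largest rank
# *before* compiling and before loading the first adapter.
pipe.enable_lora_hotswap(target_rank=64)

pipe.load_lora_weights("first_lora.safetensors", adapter_name="default")
pipe.unet = torch.compile(pipe.unet)  # optional; hotswapping avoids recompilation

# Replace the already-loaded adapter in place instead of adding a second one.
pipe.load_lora_weights("second_lora.safetensors", adapter_name="default", hotswap=True)
```

Hotswapping text encoder adapters is not yet supported; see the PEFT hotswap documentation linked in the docstrings below.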
@@ -20,24 +20,31 @@ from huggingface_hub.utils import validate_hf_hub_args
 
  from ..utils import (
  USE_PEFT_BACKEND,
- convert_state_dict_to_diffusers,
- convert_state_dict_to_peft,
  deprecate,
- get_adapter_name,
- get_peft_kwargs,
+ get_submodule_by_name,
+ is_bitsandbytes_available,
+ is_gguf_available,
  is_peft_available,
  is_peft_version,
  is_torch_version,
  is_transformers_available,
  is_transformers_version,
  logging,
- scale_lora_layers,
  )
- from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa
+ from .lora_base import ( # noqa
+ LORA_WEIGHT_NAME,
+ LORA_WEIGHT_NAME_SAFE,
+ LoraBaseMixin,
+ _fetch_state_dict,
+ _load_lora_into_text_encoder,
+ )
  from .lora_conversion_utils import (
  _convert_bfl_flux_control_lora_to_diffusers,
+ _convert_hunyuan_video_lora_to_diffusers,
  _convert_kohya_flux_lora_to_diffusers,
  _convert_non_diffusers_lora_to_diffusers,
+ _convert_non_diffusers_lumina2_lora_to_diffusers,
+ _convert_non_diffusers_wan_lora_to_diffusers,
  _convert_xlabs_flux_lora_to_diffusers,
  _maybe_map_sgm_blocks_to_diffusers,
  )
@@ -54,9 +61,6 @@ if is_torch_version(">=", "1.9.0"):
  _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
 
 
- if is_transformers_available():
- from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
-
  logger = logging.get_logger(__name__)
 
  TEXT_ENCODER_NAME = "text_encoder"
@@ -66,6 +70,49 @@ TRANSFORMER_NAME = "transformer"
  _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
 
 
+ def _maybe_dequantize_weight_for_expanded_lora(model, module):
+ if is_bitsandbytes_available():
+ from ..quantizers.bitsandbytes import dequantize_bnb_weight
+
+ if is_gguf_available():
+ from ..quantizers.gguf.utils import dequantize_gguf_tensor
+
+ is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+ is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
+
+ if is_bnb_4bit_quantized and not is_bitsandbytes_available():
+ raise ValueError(
+ "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
+ )
+ if is_gguf_quantized and not is_gguf_available():
+ raise ValueError(
+ "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
+ )
+
+ weight_on_cpu = False
+ if not module.weight.is_cuda:
+ weight_on_cpu = True
+
+ if is_bnb_4bit_quantized:
+ module_weight = dequantize_bnb_weight(
+ module.weight.cuda() if weight_on_cpu else module.weight,
+ state=module.weight.quant_state,
+ dtype=model.dtype,
+ ).data
+ elif is_gguf_quantized:
+ module_weight = dequantize_gguf_tensor(
+ module.weight.cuda() if weight_on_cpu else module.weight,
+ )
+ module_weight = module_weight.to(model.dtype)
+ else:
+ module_weight = module.weight.data
+
+ if weight_on_cpu:
+ module_weight = module_weight.cpu()
+
+ return module_weight
+
+
  class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  r"""
  Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
@@ -77,10 +124,13 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  text_encoder_name = TEXT_ENCODER_NAME
 
  def load_lora_weights(
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name=None,
+ hotswap: bool = False,
+ **kwargs,
  ):
- """
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+ """Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
  `self.text_encoder`.
 
  All kwargs are forwarded to `self.lora_state_dict`.
@@ -103,6 +153,29 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
  weights.
+ hotswap : (`bool`, *optional*)
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
  kwargs (`dict`, *optional*):
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
  """
@@ -133,6 +206,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  adapter_name=adapter_name,
  _pipeline=self,
  low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
  )
  self.load_lora_into_text_encoder(
  state_dict,
@@ -144,6 +218,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  adapter_name=adapter_name,
  _pipeline=self,
  low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
  )
 
  @classmethod
@@ -263,7 +338,14 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
 
  @classmethod
  def load_lora_into_unet(
- cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ network_alphas,
+ unet,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -285,6 +367,29 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading only loading the pretrained LoRA weights and not initializing the random
  weights.
+ hotswap : (`bool`, *optional*)
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")
@@ -297,19 +402,16 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
  # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
  # their prefixes.
- keys = list(state_dict.keys())
- only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
- if not only_text_encoder:
- # Load the layers corresponding to UNet.
- logger.info(f"Loading {cls.unet_name}.")
- unet.load_lora_adapter(
- state_dict,
- prefix=cls.unet_name,
- network_alphas=network_alphas,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
+ logger.info(f"Loading {cls.unet_name}.")
+ unet.load_lora_adapter(
+ state_dict,
+ prefix=cls.unet_name,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
 
  @classmethod
  def load_lora_into_text_encoder(
@@ -322,6 +424,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  adapter_name=None,
  _pipeline=None,
  low_cpu_mem_usage=False,
+ hotswap: bool = False,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -347,120 +450,42 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
  weights.
+ hotswap : (`bool`, *optional*)
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
  """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- peft_kwargs = {}
- if low_cpu_mem_usage:
- if not is_peft_version(">=", "0.13.1"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
- if not is_transformers_version(">", "4.45.2"):
- # Note from sayakpaul: It's not in `transformers` stable yet.
- # https://github.com/huggingface/transformers/pull/33725/
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
- )
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
- from peft import LoraConfig
-
- # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
- # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
- # their prefixes.
- keys = list(state_dict.keys())
- prefix = cls.text_encoder_name if prefix is None else prefix
-
- # Safe prefix to check with.
- if any(cls.text_encoder_name in key for key in keys):
- # Load the layers corresponding to text encoder and make necessary adjustments.
- text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
- text_encoder_lora_state_dict = {
- k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
- }
-
- if len(text_encoder_lora_state_dict) > 0:
- logger.info(f"Loading {prefix}.")
- rank = {}
- text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
-
- # convert state dict
- text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-
- for name, _ in text_encoder_attn_modules(text_encoder):
- for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in text_encoder_lora_state_dict:
- continue
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
- for name, _ in text_encoder_mlp_modules(text_encoder):
- for module in ("fc1", "fc2"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in text_encoder_lora_state_dict:
- continue
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
- if network_alphas is not None:
- alpha_keys = [
- k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
- ]
- network_alphas = {
- k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
- }
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
-
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"]:
- if is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- if is_peft_version("<", "0.9.0"):
- lora_config_kwargs.pop("use_dora")
-
- if "lora_bias" in lora_config_kwargs:
- if lora_config_kwargs["lora_bias"]:
- if is_peft_version("<=", "0.13.2"):
- raise ValueError(
- "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- if is_peft_version("<=", "0.13.2"):
- lora_config_kwargs.pop("lora_bias")
-
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(text_encoder)
-
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
- # inject LoRA layers and load the state dict
- # in transformers we automatically check whether the adapter name is already in use or not
- text_encoder.load_adapter(
- adapter_name=adapter_name,
- adapter_state_dict=text_encoder_lora_state_dict,
- peft_config=lora_config,
- **peft_kwargs,
- )
-
- # scale LoRA layers with `lora_scale`
- scale_lora_layers(text_encoder, weight=lora_scale)
-
- text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
-
- # Offload back.
- if is_model_cpu_offload:
- _pipeline.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- _pipeline.enable_sequential_cpu_offload()
- # Unsafe code />
+ _load_lora_into_text_encoder(
+ state_dict=state_dict,
+ network_alphas=network_alphas,
+ lora_scale=lora_scale,
+ text_encoder=text_encoder,
+ prefix=prefix,
+ text_encoder_name=cls.text_encoder_name,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
 
  @classmethod
  def save_lora_weights(
@@ -556,7 +581,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  ```
  """
  super().fuse_lora(
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
  )
 
  def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
@@ -577,7 +606,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
  LoRA parameters then it won't have any effect.
  """
- super().unfuse_lora(components=components)
+ super().unfuse_lora(components=components, **kwargs)
 
 
  class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
@@ -660,31 +689,26 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  _pipeline=self,
  low_cpu_mem_usage=low_cpu_mem_usage,
  )
- text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
- if len(text_encoder_state_dict) > 0:
- self.load_lora_into_text_encoder(
- text_encoder_state_dict,
- network_alphas=network_alphas,
- text_encoder=self.text_encoder,
- prefix="text_encoder",
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
-
- text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
- if len(text_encoder_2_state_dict) > 0:
- self.load_lora_into_text_encoder(
- text_encoder_2_state_dict,
- network_alphas=network_alphas,
- text_encoder=self.text_encoder_2,
- prefix="text_encoder_2",
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder,
+ prefix=self.text_encoder_name,
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder_2,
+ prefix=f"{self.text_encoder_name}_2",
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
 
  @classmethod
  @validate_hf_hub_args
@@ -805,7 +829,14 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  @classmethod
  # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
  def load_lora_into_unet(
- cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ network_alphas,
+ unet,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -827,6 +858,29 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading only loading the pretrained LoRA weights and not initializing the random
  weights.
+ hotswap : (`bool`, *optional*)
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")
@@ -839,19 +893,16 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
  # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
  # their prefixes.
- keys = list(state_dict.keys())
- only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
- if not only_text_encoder:
- # Load the layers corresponding to UNet.
- logger.info(f"Loading {cls.unet_name}.")
- unet.load_lora_adapter(
- state_dict,
- prefix=cls.unet_name,
- network_alphas=network_alphas,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
+ logger.info(f"Loading {cls.unet_name}.")
+ unet.load_lora_adapter(
+ state_dict,
+ prefix=cls.unet_name,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
 
  @classmethod
  # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -865,6 +916,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  adapter_name=None,
  _pipeline=None,
  low_cpu_mem_usage=False,
+ hotswap: bool = False,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -890,120 +942,42 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
890
942
  low_cpu_mem_usage (`bool`, *optional*):
891
943
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
892
944
  weights.
945
+ hotswap : (`bool`, *optional*)
946
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
947
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
948
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
949
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
950
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
951
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
952
+
953
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
954
+ to call an additional method before loading the adapter:
955
+
956
+ ```py
957
+ pipeline = ... # load diffusers pipeline
958
+ max_rank = ... # the highest rank among all LoRAs that you want to load
959
+ # call *before* compiling and loading the LoRA adapter
960
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
961
+ pipeline.load_lora_weights(file_name)
962
+ # optionally compile the model now
963
+ ```
964
+
965
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
966
+ limitations to this technique, which are documented here:
967
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
893
968
  """
894
- if not USE_PEFT_BACKEND:
895
- raise ValueError("PEFT backend is required for this method.")
896
-
897
- peft_kwargs = {}
898
- if low_cpu_mem_usage:
899
- if not is_peft_version(">=", "0.13.1"):
900
- raise ValueError(
901
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
902
- )
903
- if not is_transformers_version(">", "4.45.2"):
904
- # Note from sayakpaul: It's not in `transformers` stable yet.
905
- # https://github.com/huggingface/transformers/pull/33725/
906
- raise ValueError(
907
- "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
908
- )
909
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
910
-
911
- from peft import LoraConfig
912
-
913
- # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
914
- # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
915
- # their prefixes.
916
- keys = list(state_dict.keys())
917
- prefix = cls.text_encoder_name if prefix is None else prefix
918
-
919
- # Safe prefix to check with.
920
- if any(cls.text_encoder_name in key for key in keys):
921
- # Load the layers corresponding to text encoder and make necessary adjustments.
922
- text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
923
- text_encoder_lora_state_dict = {
924
- k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
925
- }
926
-
927
- if len(text_encoder_lora_state_dict) > 0:
928
- logger.info(f"Loading {prefix}.")
929
- rank = {}
930
- text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
931
-
932
- # convert state dict
933
- text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
934
-
935
- for name, _ in text_encoder_attn_modules(text_encoder):
936
- for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
937
- rank_key = f"{name}.{module}.lora_B.weight"
938
- if rank_key not in text_encoder_lora_state_dict:
939
- continue
940
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
941
-
942
- for name, _ in text_encoder_mlp_modules(text_encoder):
943
- for module in ("fc1", "fc2"):
944
- rank_key = f"{name}.{module}.lora_B.weight"
945
- if rank_key not in text_encoder_lora_state_dict:
946
- continue
947
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
948
-
949
- if network_alphas is not None:
950
- alpha_keys = [
951
- k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
952
- ]
953
- network_alphas = {
954
- k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
955
- }
956
-
957
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
958
-
959
- if "use_dora" in lora_config_kwargs:
960
- if lora_config_kwargs["use_dora"]:
961
- if is_peft_version("<", "0.9.0"):
962
- raise ValueError(
963
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
964
- )
965
- else:
966
- if is_peft_version("<", "0.9.0"):
967
- lora_config_kwargs.pop("use_dora")
968
-
969
- if "lora_bias" in lora_config_kwargs:
970
- if lora_config_kwargs["lora_bias"]:
971
- if is_peft_version("<=", "0.13.2"):
972
- raise ValueError(
973
- "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
974
- )
975
- else:
976
- if is_peft_version("<=", "0.13.2"):
977
- lora_config_kwargs.pop("lora_bias")
978
-
979
- lora_config = LoraConfig(**lora_config_kwargs)
980
-
981
- # adapter_name
982
- if adapter_name is None:
983
- adapter_name = get_adapter_name(text_encoder)
984
-
985
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
986
-
987
- # inject LoRA layers and load the state dict
988
- # in transformers we automatically check whether the adapter name is already in use or not
989
- text_encoder.load_adapter(
990
- adapter_name=adapter_name,
991
- adapter_state_dict=text_encoder_lora_state_dict,
992
- peft_config=lora_config,
993
- **peft_kwargs,
994
- )
995
-
996
- # scale LoRA layers with `lora_scale`
997
- scale_lora_layers(text_encoder, weight=lora_scale)
998
-
999
- text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
1000
-
1001
- # Offload back.
1002
- if is_model_cpu_offload:
1003
- _pipeline.enable_model_cpu_offload()
1004
- elif is_sequential_cpu_offload:
1005
- _pipeline.enable_sequential_cpu_offload()
1006
- # Unsafe code />
969
+ _load_lora_into_text_encoder(
970
+ state_dict=state_dict,
971
+ network_alphas=network_alphas,
972
+ lora_scale=lora_scale,
973
+ text_encoder=text_encoder,
974
+ prefix=prefix,
975
+ text_encoder_name=cls.text_encoder_name,
976
+ adapter_name=adapter_name,
977
+ _pipeline=_pipeline,
978
+ low_cpu_mem_usage=low_cpu_mem_usage,
979
+ hotswap=hotswap,
980
+ )
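The block above replaces the long inline PEFT plumbing with a single call to the shared `_load_lora_into_text_encoder` helper in `lora_base.py`. Below is a minimal sketch of the steps that helper consolidates, reconstructed from the removed code shown above; it is intentionally simplified (the real helper also performs `peft`/`transformers` version checks, DoRA and `lora_bias` handling, offload bookkeeping, and hotswapping), and the function name is illustrative.

```py
# Simplified sketch of the consolidated text-encoder LoRA loading path; names below mirror
# the call site, but this is not the actual helper implementation.
from peft import LoraConfig

from diffusers.utils import (
    convert_state_dict_to_diffusers,
    convert_state_dict_to_peft,
    get_adapter_name,
    get_peft_kwargs,
    scale_lora_layers,
)


def load_text_encoder_lora_sketch(state_dict, network_alphas, text_encoder, prefix, lora_scale=1.0, adapter_name=None):
    # Keep only keys whose first segment equals `prefix` (e.g. "text_encoder") and strip that prefix.
    te_state_dict = {
        k.replace(f"{prefix}.", "", 1): v for k, v in state_dict.items() if k.split(".")[0] == prefix
    }
    if not te_state_dict:
        return

    # Convert to PEFT naming and derive per-module ranks from the lora_B matrices
    # (the original code restricts this to attention and MLP modules).
    te_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(te_state_dict))
    rank = {k: v.shape[1] for k, v in te_state_dict.items() if k.endswith("lora_B.weight")}

    # Build the LoraConfig and let transformers' PEFT integration inject and load the adapter.
    lora_config = LoraConfig(**get_peft_kwargs(rank, network_alphas, te_state_dict, is_unet=False))
    adapter_name = adapter_name or get_adapter_name(text_encoder)
    text_encoder.load_adapter(adapter_name=adapter_name, adapter_state_dict=te_state_dict, peft_config=lora_config)

    # Finally apply the requested LoRA scale.
    scale_lora_layers(text_encoder, weight=lora_scale)
```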
1007
981
 
1008
982
  @classmethod
1009
983
  def save_lora_weights(
@@ -1046,11 +1020,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
1046
1020
 
1047
1021
  if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
1048
1022
  raise ValueError(
1049
- "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
1023
+ "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
1050
1024
  )
1051
1025
 
1052
1026
  if unet_lora_layers:
1053
- state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
1027
+ state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
1054
1028
 
1055
1029
  if text_encoder_lora_layers:
1056
1030
  state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
@@ -1107,7 +1081,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
1107
1081
  ```
1108
1082
  """
1109
1083
  super().fuse_lora(
1110
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
1084
+ components=components,
1085
+ lora_scale=lora_scale,
1086
+ safe_fusing=safe_fusing,
1087
+ adapter_names=adapter_names,
1088
+ **kwargs,
1111
1089
  )
1112
1090
 
1113
1091
  def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
@@ -1128,7 +1106,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
1128
1106
  Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
1129
1107
  LoRA parameters then it won't have any effect.
1130
1108
  """
1131
- super().unfuse_lora(components=components)
1109
+ super().unfuse_lora(components=components, **kwargs)
1132
1110
 
1133
1111
 
1134
1112
  class SD3LoraLoaderMixin(LoraBaseMixin):
@@ -1242,7 +1220,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1242
1220
  return state_dict
1243
1221
 
1244
1222
  def load_lora_weights(
1245
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
1223
+ self,
1224
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
1225
+ adapter_name=None,
1226
+ hotswap: bool = False,
1227
+ **kwargs,
1246
1228
  ):
1247
1229
  """
1248
1230
  Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
@@ -1265,6 +1247,29 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1265
1247
  low_cpu_mem_usage (`bool`, *optional*):
1266
1248
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1267
1249
  weights.
1250
+ hotswap : (`bool`, *optional*)
1251
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
1252
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
1253
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
1254
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
1255
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
1256
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
1257
+
1258
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
1259
+ to call an additional method before loading the adapter:
1260
+
1261
+ ```py
1262
+ pipeline = ... # load diffusers pipeline
1263
+ max_rank = ... # the highest rank among all LoRAs that you want to load
1264
+ # call *before* compiling and loading the LoRA adapter
1265
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
1266
+ pipeline.load_lora_weights(file_name)
1267
+ # optionally compile the model now
1268
+ ```
1269
+
1270
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
1271
+ limitations to this technique, which are documented here:
1272
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
1268
1273
  kwargs (`dict`, *optional*):
1269
1274
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
1270
1275
  """
@@ -1288,47 +1293,40 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1288
1293
  if not is_correct_format:
1289
1294
  raise ValueError("Invalid LoRA checkpoint.")
1290
1295
 
1291
- transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
1292
- if len(transformer_state_dict) > 0:
1293
- self.load_lora_into_transformer(
1294
- state_dict,
1295
- transformer=getattr(self, self.transformer_name)
1296
- if not hasattr(self, "transformer")
1297
- else self.transformer,
1298
- adapter_name=adapter_name,
1299
- _pipeline=self,
1300
- low_cpu_mem_usage=low_cpu_mem_usage,
1301
- )
1302
-
1303
- text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
1304
- if len(text_encoder_state_dict) > 0:
1305
- self.load_lora_into_text_encoder(
1306
- text_encoder_state_dict,
1307
- network_alphas=None,
1308
- text_encoder=self.text_encoder,
1309
- prefix="text_encoder",
1310
- lora_scale=self.lora_scale,
1311
- adapter_name=adapter_name,
1312
- _pipeline=self,
1313
- low_cpu_mem_usage=low_cpu_mem_usage,
1314
- )
1315
-
1316
- text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
1317
- if len(text_encoder_2_state_dict) > 0:
1318
- self.load_lora_into_text_encoder(
1319
- text_encoder_2_state_dict,
1320
- network_alphas=None,
1321
- text_encoder=self.text_encoder_2,
1322
- prefix="text_encoder_2",
1323
- lora_scale=self.lora_scale,
1324
- adapter_name=adapter_name,
1325
- _pipeline=self,
1326
- low_cpu_mem_usage=low_cpu_mem_usage,
1327
- )
1296
+ self.load_lora_into_transformer(
1297
+ state_dict,
1298
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
1299
+ adapter_name=adapter_name,
1300
+ _pipeline=self,
1301
+ low_cpu_mem_usage=low_cpu_mem_usage,
1302
+ hotswap=hotswap,
1303
+ )
1304
+ self.load_lora_into_text_encoder(
1305
+ state_dict,
1306
+ network_alphas=None,
1307
+ text_encoder=self.text_encoder,
1308
+ prefix=self.text_encoder_name,
1309
+ lora_scale=self.lora_scale,
1310
+ adapter_name=adapter_name,
1311
+ _pipeline=self,
1312
+ low_cpu_mem_usage=low_cpu_mem_usage,
1313
+ hotswap=hotswap,
1314
+ )
1315
+ self.load_lora_into_text_encoder(
1316
+ state_dict,
1317
+ network_alphas=None,
1318
+ text_encoder=self.text_encoder_2,
1319
+ prefix=f"{self.text_encoder_name}_2",
1320
+ lora_scale=self.lora_scale,
1321
+ adapter_name=adapter_name,
1322
+ _pipeline=self,
1323
+ low_cpu_mem_usage=low_cpu_mem_usage,
1324
+ hotswap=hotswap,
1325
+ )
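The refactor above passes the full `state_dict` to each component loader and relies on exact first-segment prefix matching (`transformer`, `text_encoder`, `text_encoder_2`) instead of pre-filtered sub-dicts. A small illustration with dummy keys and tensors shows why that matching keeps the two text-encoder namespaces apart:

```py
# Dummy key names and shapes, for illustration only.
import torch

state_dict = {
    "transformer.transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 1536),
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(4, 768),
    "text_encoder_2.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(4, 1280),
}

# First-segment matching (what the loaders do internally) does not let "text_encoder"
# swallow the "text_encoder_2" keys.
te_keys = [k for k in state_dict if k.split(".")[0] == "text_encoder"]
te2_keys = [k for k in state_dict if k.split(".")[0] == "text_encoder_2"]
assert len(te_keys) == 1 and len(te2_keys) == 1
```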
1328
1326
 
1329
1327
  @classmethod
1330
1328
  def load_lora_into_transformer(
1331
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
1329
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
1332
1330
  ):
1333
1331
  """
1334
1332
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1346,6 +1344,29 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1346
1344
  low_cpu_mem_usage (`bool`, *optional*):
1347
1345
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1348
1346
  weights.
1347
+ hotswap : (`bool`, *optional*)
1348
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
1349
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
1350
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
1351
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
1352
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
1353
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
1354
+
1355
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
1356
+ to call an additional method before loading the adapter:
1357
+
1358
+ ```py
1359
+ pipeline = ... # load diffusers pipeline
1360
+ max_rank = ... # the highest rank among all LoRAs that you want to load
1361
+ # call *before* compiling and loading the LoRA adapter
1362
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
1363
+ pipeline.load_lora_weights(file_name)
1364
+ # optionally compile the model now
1365
+ ```
1366
+
1367
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
1368
+ limitations to this technique, which are documented here:
1369
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
1349
1370
  """
1350
1371
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
1351
1372
  raise ValueError(
@@ -1360,6 +1381,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1360
1381
  adapter_name=adapter_name,
1361
1382
  _pipeline=_pipeline,
1362
1383
  low_cpu_mem_usage=low_cpu_mem_usage,
1384
+ hotswap=hotswap,
1363
1385
  )
1364
1386
 
1365
1387
  @classmethod
@@ -1374,6 +1396,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1374
1396
  adapter_name=None,
1375
1397
  _pipeline=None,
1376
1398
  low_cpu_mem_usage=False,
1399
+ hotswap: bool = False,
1377
1400
  ):
1378
1401
  """
1379
1402
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1399,126 +1422,49 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1399
1422
  low_cpu_mem_usage (`bool`, *optional*):
1400
1423
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1401
1424
  weights.
1425
+ hotswap : (`bool`, *optional*)
1426
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
1427
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
1428
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
1429
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
1430
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
1431
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
1432
+
1433
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
1434
+ to call an additional method before loading the adapter:
1435
+
1436
+ ```py
1437
+ pipeline = ... # load diffusers pipeline
1438
+ max_rank = ... # the highest rank among all LoRAs that you want to load
1439
+ # call *before* compiling and loading the LoRA adapter
1440
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
1441
+ pipeline.load_lora_weights(file_name)
1442
+ # optionally compile the model now
1443
+ ```
1444
+
1445
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
1446
+ limitations to this technique, which are documented here:
1447
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
1402
1448
  """
1403
- if not USE_PEFT_BACKEND:
1404
- raise ValueError("PEFT backend is required for this method.")
1405
-
1406
- peft_kwargs = {}
1407
- if low_cpu_mem_usage:
1408
- if not is_peft_version(">=", "0.13.1"):
1409
- raise ValueError(
1410
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
1411
- )
1412
- if not is_transformers_version(">", "4.45.2"):
1413
- # Note from sayakpaul: It's not in `transformers` stable yet.
1414
- # https://github.com/huggingface/transformers/pull/33725/
1415
- raise ValueError(
1416
- "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
1417
- )
1418
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
1419
-
1420
- from peft import LoraConfig
1421
-
1422
- # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
1423
- # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
1424
- # their prefixes.
1425
- keys = list(state_dict.keys())
1426
- prefix = cls.text_encoder_name if prefix is None else prefix
1427
-
1428
- # Safe prefix to check with.
1429
- if any(cls.text_encoder_name in key for key in keys):
1430
- # Load the layers corresponding to text encoder and make necessary adjustments.
1431
- text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
1432
- text_encoder_lora_state_dict = {
1433
- k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
1434
- }
1435
-
1436
- if len(text_encoder_lora_state_dict) > 0:
1437
- logger.info(f"Loading {prefix}.")
1438
- rank = {}
1439
- text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
1440
-
1441
- # convert state dict
1442
- text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
1443
-
1444
- for name, _ in text_encoder_attn_modules(text_encoder):
1445
- for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
1446
- rank_key = f"{name}.{module}.lora_B.weight"
1447
- if rank_key not in text_encoder_lora_state_dict:
1448
- continue
1449
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
1450
-
1451
- for name, _ in text_encoder_mlp_modules(text_encoder):
1452
- for module in ("fc1", "fc2"):
1453
- rank_key = f"{name}.{module}.lora_B.weight"
1454
- if rank_key not in text_encoder_lora_state_dict:
1455
- continue
1456
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
1457
-
1458
- if network_alphas is not None:
1459
- alpha_keys = [
1460
- k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
1461
- ]
1462
- network_alphas = {
1463
- k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
1464
- }
1465
-
1466
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
1467
-
1468
- if "use_dora" in lora_config_kwargs:
1469
- if lora_config_kwargs["use_dora"]:
1470
- if is_peft_version("<", "0.9.0"):
1471
- raise ValueError(
1472
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
1473
- )
1474
- else:
1475
- if is_peft_version("<", "0.9.0"):
1476
- lora_config_kwargs.pop("use_dora")
1477
-
1478
- if "lora_bias" in lora_config_kwargs:
1479
- if lora_config_kwargs["lora_bias"]:
1480
- if is_peft_version("<=", "0.13.2"):
1481
- raise ValueError(
1482
- "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
1483
- )
1484
- else:
1485
- if is_peft_version("<=", "0.13.2"):
1486
- lora_config_kwargs.pop("lora_bias")
1487
-
1488
- lora_config = LoraConfig(**lora_config_kwargs)
1489
-
1490
- # adapter_name
1491
- if adapter_name is None:
1492
- adapter_name = get_adapter_name(text_encoder)
1493
-
1494
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
1495
-
1496
- # inject LoRA layers and load the state dict
1497
- # in transformers we automatically check whether the adapter name is already in use or not
1498
- text_encoder.load_adapter(
1499
- adapter_name=adapter_name,
1500
- adapter_state_dict=text_encoder_lora_state_dict,
1501
- peft_config=lora_config,
1502
- **peft_kwargs,
1503
- )
1504
-
1505
- # scale LoRA layers with `lora_scale`
1506
- scale_lora_layers(text_encoder, weight=lora_scale)
1507
-
1508
- text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
1509
-
1510
- # Offload back.
1511
- if is_model_cpu_offload:
1512
- _pipeline.enable_model_cpu_offload()
1513
- elif is_sequential_cpu_offload:
1514
- _pipeline.enable_sequential_cpu_offload()
1515
- # Unsafe code />
1449
+ _load_lora_into_text_encoder(
1450
+ state_dict=state_dict,
1451
+ network_alphas=network_alphas,
1452
+ lora_scale=lora_scale,
1453
+ text_encoder=text_encoder,
1454
+ prefix=prefix,
1455
+ text_encoder_name=cls.text_encoder_name,
1456
+ adapter_name=adapter_name,
1457
+ _pipeline=_pipeline,
1458
+ low_cpu_mem_usage=low_cpu_mem_usage,
1459
+ hotswap=hotswap,
1460
+ )
1516
1461
 
1517
1462
  @classmethod
1463
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
1518
1464
  def save_lora_weights(
1519
1465
  cls,
1520
1466
  save_directory: Union[str, os.PathLike],
1521
- transformer_lora_layers: Dict[str, torch.nn.Module] = None,
1467
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1522
1468
  text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1523
1469
  text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1524
1470
  is_main_process: bool = True,
@@ -1567,7 +1513,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1567
1513
  if text_encoder_2_lora_layers:
1568
1514
  state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
1569
1515
 
1570
- # Save the model
1571
1516
  cls.write_lora_layers(
1572
1517
  state_dict=state_dict,
1573
1518
  save_directory=save_directory,
@@ -1577,6 +1522,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1577
1522
  safe_serialization=safe_serialization,
1578
1523
  )
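`save_lora_weights` above packs each component's layers under its prefix and hands the merged dict to `write_lora_layers`. A hedged sketch of persisting a fine-tuned SD3 transformer adapter, following the pattern used by the diffusers training scripts; `pipe` is assumed to already carry a PEFT LoRA adapter on its transformer.

```py
# Hedged sketch: extract the adapter weights with peft and save them in the diffusers layout.
from peft.utils import get_peft_model_state_dict

from diffusers import StableDiffusion3Pipeline

transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)

StableDiffusion3Pipeline.save_lora_weights(
    save_directory="my-sd3-lora",
    transformer_lora_layers=transformer_lora_layers,
    safe_serialization=True,  # writes a safetensors LoRA file into the directory
)
```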
1579
1524
 
1525
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
1580
1526
  def fuse_lora(
1581
1527
  self,
1582
1528
  components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
@@ -1617,9 +1563,14 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1617
1563
  ```
1618
1564
  """
1619
1565
  super().fuse_lora(
1620
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
1566
+ components=components,
1567
+ lora_scale=lora_scale,
1568
+ safe_fusing=safe_fusing,
1569
+ adapter_names=adapter_names,
1570
+ **kwargs,
1621
1571
  )
1622
1572
 
1573
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
1623
1574
  def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
1624
1575
  r"""
1625
1576
  Reverses the effect of
@@ -1633,12 +1584,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1633
1584
 
1634
1585
  Args:
1635
1586
  components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
1636
- unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
1587
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
1637
1588
  unfuse_text_encoder (`bool`, defaults to `True`):
1638
1589
  Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
1639
1590
  LoRA parameters then it won't have any effect.
1640
1591
  """
1641
- super().unfuse_lora(components=components)
1592
+ super().unfuse_lora(components=components, **kwargs)
1642
1593
 
1643
1594
 
1644
1595
  class FluxLoraLoaderMixin(LoraBaseMixin):
@@ -1789,7 +1740,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1789
1740
  return state_dict
1790
1741
 
1791
1742
  def load_lora_weights(
1792
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
1743
+ self,
1744
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
1745
+ adapter_name=None,
1746
+ hotswap: bool = False,
1747
+ **kwargs,
1793
1748
  ):
1794
1749
  """
1795
1750
  Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
@@ -1814,6 +1769,26 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1814
1769
  low_cpu_mem_usage (`bool`, *optional*):
1815
1770
 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1816
1771
  weights.
1772
+ hotswap : (`bool`, *optional*)
1773
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
1774
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
1775
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
1776
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
1777
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
1778
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new
1779
+ adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an
1780
+ additional method before loading the adapter:
1781
+ ```py
1782
+ pipeline = ... # load diffusers pipeline
1783
+ max_rank = ... # the highest rank among all LoRAs that you want to load
1784
+ # call *before* compiling and loading the LoRA adapter
1785
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
1786
+ pipeline.load_lora_weights(file_name)
1787
+ # optionally compile the model now
1788
+ ```
1789
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
1790
+ limitations to this technique, which are documented here:
1791
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
1817
1792
  """
1818
1793
  if not USE_PEFT_BACKEND:
1819
1794
  raise ValueError("PEFT backend is required for this method.")
@@ -1844,18 +1819,23 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1844
1819
  raise ValueError("Invalid LoRA checkpoint.")
1845
1820
 
1846
1821
  transformer_lora_state_dict = {
1847
- k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
1822
+ k: state_dict.get(k)
1823
+ for k in list(state_dict.keys())
1824
+ if k.startswith(f"{self.transformer_name}.") and "lora" in k
1848
1825
  }
1849
1826
  transformer_norm_state_dict = {
1850
1827
  k: state_dict.pop(k)
1851
1828
  for k in list(state_dict.keys())
1852
- if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
1829
+ if k.startswith(f"{self.transformer_name}.")
1830
+ and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
1853
1831
  }
1854
1832
 
1855
1833
  transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
1856
- has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
1857
- transformer, transformer_lora_state_dict, transformer_norm_state_dict
1858
- )
1834
+ has_param_with_expanded_shape = False
1835
+ if len(transformer_lora_state_dict) > 0:
1836
+ has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
1837
+ transformer, transformer_lora_state_dict, transformer_norm_state_dict
1838
+ )
1859
1839
 
1860
1840
  if has_param_with_expanded_shape:
1861
1841
  logger.info(
@@ -1863,19 +1843,22 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1863
1843
  "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
1864
1844
  "To get a comprehensive list of parameter names that were modified, enable debug logging."
1865
1845
  )
1866
- transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
1867
- transformer=transformer, lora_state_dict=transformer_lora_state_dict
1868
- )
1869
-
1870
1846
  if len(transformer_lora_state_dict) > 0:
1871
- self.load_lora_into_transformer(
1872
- transformer_lora_state_dict,
1873
- network_alphas=network_alphas,
1874
- transformer=transformer,
1875
- adapter_name=adapter_name,
1876
- _pipeline=self,
1877
- low_cpu_mem_usage=low_cpu_mem_usage,
1847
+ transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
1848
+ transformer=transformer, lora_state_dict=transformer_lora_state_dict
1878
1849
  )
1850
+ for k in transformer_lora_state_dict:
1851
+ state_dict.update({k: transformer_lora_state_dict[k]})
1852
+
1853
+ self.load_lora_into_transformer(
1854
+ state_dict,
1855
+ network_alphas=network_alphas,
1856
+ transformer=transformer,
1857
+ adapter_name=adapter_name,
1858
+ _pipeline=self,
1859
+ low_cpu_mem_usage=low_cpu_mem_usage,
1860
+ hotswap=hotswap,
1861
+ )
1879
1862
 
1880
1863
  if len(transformer_norm_state_dict) > 0:
1881
1864
  transformer._transformer_norm_layers = self._load_norm_into_transformer(
@@ -1884,22 +1867,28 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1884
1867
  discard_original_layers=False,
1885
1868
  )
1886
1869
 
1887
- text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
1888
- if len(text_encoder_state_dict) > 0:
1889
- self.load_lora_into_text_encoder(
1890
- text_encoder_state_dict,
1891
- network_alphas=network_alphas,
1892
- text_encoder=self.text_encoder,
1893
- prefix="text_encoder",
1894
- lora_scale=self.lora_scale,
1895
- adapter_name=adapter_name,
1896
- _pipeline=self,
1897
- low_cpu_mem_usage=low_cpu_mem_usage,
1898
- )
1870
+ self.load_lora_into_text_encoder(
1871
+ state_dict,
1872
+ network_alphas=network_alphas,
1873
+ text_encoder=self.text_encoder,
1874
+ prefix=self.text_encoder_name,
1875
+ lora_scale=self.lora_scale,
1876
+ adapter_name=adapter_name,
1877
+ _pipeline=self,
1878
+ low_cpu_mem_usage=low_cpu_mem_usage,
1879
+ hotswap=hotswap,
1880
+ )
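The Flux path above now filters transformer keys with an exact `transformer.` prefix, expands mismatched base `Linear` shapes only when LoRA keys are actually present, and loads any control-LoRA norm layers separately. A hedged usage sketch follows; the repo ids are examples and may differ from your setup.

```py
# Hedged sketch: loading a control-style Flux LoRA that ships expanded input projections
# and extra norm layers. load_lora_weights() widens the base Linear layers as needed and
# stashes the originals so they can be restored on unload.
import torch
from diffusers import FluxControlPipeline

pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora", adapter_name="canny")
```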
1899
1881
 
1900
1882
  @classmethod
1901
1883
  def load_lora_into_transformer(
1902
- cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
1884
+ cls,
1885
+ state_dict,
1886
+ network_alphas,
1887
+ transformer,
1888
+ adapter_name=None,
1889
+ _pipeline=None,
1890
+ low_cpu_mem_usage=False,
1891
+ hotswap: bool = False,
1903
1892
  ):
1904
1893
  """
1905
1894
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1921,6 +1910,29 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1921
1910
  low_cpu_mem_usage (`bool`, *optional*):
1922
1911
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1923
1912
  weights.
1913
+ hotswap : (`bool`, *optional*)
1914
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
1915
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
1916
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
1917
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
1918
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
1919
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
1920
+
1921
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
1922
+ to call an additional method before loading the adapter:
1923
+
1924
+ ```py
1925
+ pipeline = ... # load diffusers pipeline
1926
+ max_rank = ... # the highest rank among all LoRAs that you want to load
1927
+ # call *before* compiling and loading the LoRA adapter
1928
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
1929
+ pipeline.load_lora_weights(file_name)
1930
+ # optionally compile the model now
1931
+ ```
1932
+
1933
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
1934
+ limitations to this technique, which are documented here:
1935
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
1924
1936
  """
1925
1937
  if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
1926
1938
  raise ValueError(
@@ -1928,17 +1940,15 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1928
1940
  )
1929
1941
 
1930
1942
  # Load the layers corresponding to transformer.
1931
- keys = list(state_dict.keys())
1932
- transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
1933
- if transformer_present:
1934
- logger.info(f"Loading {cls.transformer_name}.")
1935
- transformer.load_lora_adapter(
1936
- state_dict,
1937
- network_alphas=network_alphas,
1938
- adapter_name=adapter_name,
1939
- _pipeline=_pipeline,
1940
- low_cpu_mem_usage=low_cpu_mem_usage,
1941
- )
1943
+ logger.info(f"Loading {cls.transformer_name}.")
1944
+ transformer.load_lora_adapter(
1945
+ state_dict,
1946
+ network_alphas=network_alphas,
1947
+ adapter_name=adapter_name,
1948
+ _pipeline=_pipeline,
1949
+ low_cpu_mem_usage=low_cpu_mem_usage,
1950
+ hotswap=hotswap,
1951
+ )
1942
1952
 
1943
1953
  @classmethod
1944
1954
  def _load_norm_into_transformer(
@@ -2006,6 +2016,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2006
2016
  adapter_name=None,
2007
2017
  _pipeline=None,
2008
2018
  low_cpu_mem_usage=False,
2019
+ hotswap: bool = False,
2009
2020
  ):
2010
2021
  """
2011
2022
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -2031,120 +2042,42 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2031
2042
  low_cpu_mem_usage (`bool`, *optional*):
2032
2043
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2033
2044
  weights.
2045
+ hotswap : (`bool`, *optional*)
2046
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
2047
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
2048
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
2049
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
2050
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
2051
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
2052
+
2053
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
2054
+ to call an additional method before loading the adapter:
2055
+
2056
+ ```py
2057
+ pipeline = ... # load diffusers pipeline
2058
+ max_rank = ... # the highest rank among all LoRAs that you want to load
2059
+ # call *before* compiling and loading the LoRA adapter
2060
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
2061
+ pipeline.load_lora_weights(file_name)
2062
+ # optionally compile the model now
2063
+ ```
2064
+
2065
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
2066
+ limitations to this technique, which are documented here:
2067
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
2034
2068
  """
2035
- if not USE_PEFT_BACKEND:
2036
- raise ValueError("PEFT backend is required for this method.")
2037
-
2038
- peft_kwargs = {}
2039
- if low_cpu_mem_usage:
2040
- if not is_peft_version(">=", "0.13.1"):
2041
- raise ValueError(
2042
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
2043
- )
2044
- if not is_transformers_version(">", "4.45.2"):
2045
- # Note from sayakpaul: It's not in `transformers` stable yet.
2046
- # https://github.com/huggingface/transformers/pull/33725/
2047
- raise ValueError(
2048
- "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
2049
- )
2050
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
2051
-
2052
- from peft import LoraConfig
2053
-
2054
- # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
2055
- # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
2056
- # their prefixes.
2057
- keys = list(state_dict.keys())
2058
- prefix = cls.text_encoder_name if prefix is None else prefix
2059
-
2060
- # Safe prefix to check with.
2061
- if any(cls.text_encoder_name in key for key in keys):
2062
- # Load the layers corresponding to text encoder and make necessary adjustments.
2063
- text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
2064
- text_encoder_lora_state_dict = {
2065
- k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
2066
- }
2067
-
2068
- if len(text_encoder_lora_state_dict) > 0:
2069
- logger.info(f"Loading {prefix}.")
2070
- rank = {}
2071
- text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
2072
-
2073
- # convert state dict
2074
- text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
2075
-
2076
- for name, _ in text_encoder_attn_modules(text_encoder):
2077
- for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
2078
- rank_key = f"{name}.{module}.lora_B.weight"
2079
- if rank_key not in text_encoder_lora_state_dict:
2080
- continue
2081
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
2082
-
2083
- for name, _ in text_encoder_mlp_modules(text_encoder):
2084
- for module in ("fc1", "fc2"):
2085
- rank_key = f"{name}.{module}.lora_B.weight"
2086
- if rank_key not in text_encoder_lora_state_dict:
2087
- continue
2088
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
2089
-
2090
- if network_alphas is not None:
2091
- alpha_keys = [
2092
- k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
2093
- ]
2094
- network_alphas = {
2095
- k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
2096
- }
2097
-
2098
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
2099
-
2100
- if "use_dora" in lora_config_kwargs:
2101
- if lora_config_kwargs["use_dora"]:
2102
- if is_peft_version("<", "0.9.0"):
2103
- raise ValueError(
2104
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
2105
- )
2106
- else:
2107
- if is_peft_version("<", "0.9.0"):
2108
- lora_config_kwargs.pop("use_dora")
2109
-
2110
- if "lora_bias" in lora_config_kwargs:
2111
- if lora_config_kwargs["lora_bias"]:
2112
- if is_peft_version("<=", "0.13.2"):
2113
- raise ValueError(
2114
- "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
2115
- )
2116
- else:
2117
- if is_peft_version("<=", "0.13.2"):
2118
- lora_config_kwargs.pop("lora_bias")
2119
-
2120
- lora_config = LoraConfig(**lora_config_kwargs)
2121
-
2122
- # adapter_name
2123
- if adapter_name is None:
2124
- adapter_name = get_adapter_name(text_encoder)
2125
-
2126
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
2127
-
2128
- # inject LoRA layers and load the state dict
2129
- # in transformers we automatically check whether the adapter name is already in use or not
2130
- text_encoder.load_adapter(
2131
- adapter_name=adapter_name,
2132
- adapter_state_dict=text_encoder_lora_state_dict,
2133
- peft_config=lora_config,
2134
- **peft_kwargs,
2135
- )
2136
-
2137
- # scale LoRA layers with `lora_scale`
2138
- scale_lora_layers(text_encoder, weight=lora_scale)
2139
-
2140
- text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
2141
-
2142
- # Offload back.
2143
- if is_model_cpu_offload:
2144
- _pipeline.enable_model_cpu_offload()
2145
- elif is_sequential_cpu_offload:
2146
- _pipeline.enable_sequential_cpu_offload()
2147
- # Unsafe code />
2069
+ _load_lora_into_text_encoder(
2070
+ state_dict=state_dict,
2071
+ network_alphas=network_alphas,
2072
+ lora_scale=lora_scale,
2073
+ text_encoder=text_encoder,
2074
+ prefix=prefix,
2075
+ text_encoder_name=cls.text_encoder_name,
2076
+ adapter_name=adapter_name,
2077
+ _pipeline=_pipeline,
2078
+ low_cpu_mem_usage=low_cpu_mem_usage,
2079
+ hotswap=hotswap,
2080
+ )
2148
2081
 
2149
2082
  @classmethod
2150
2083
  # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
@@ -2203,7 +2136,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2203
2136
 
2204
2137
  def fuse_lora(
2205
2138
  self,
2206
- components: List[str] = ["transformer", "text_encoder"],
2139
+ components: List[str] = ["transformer"],
2207
2140
  lora_scale: float = 1.0,
2208
2141
  safe_fusing: bool = False,
2209
2142
  adapter_names: Optional[List[str]] = None,
@@ -2254,7 +2187,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2254
2187
  )
2255
2188
 
2256
2189
  super().fuse_lora(
2257
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
2190
+ components=components,
2191
+ lora_scale=lora_scale,
2192
+ safe_fusing=safe_fusing,
2193
+ adapter_names=adapter_names,
2194
+ **kwargs,
2258
2195
  )
2259
2196
 
2260
2197
  def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
@@ -2275,10 +2212,26 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2275
2212
  if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
2276
2213
  transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
2277
2214
 
2278
- super().unfuse_lora(components=components)
2215
+ super().unfuse_lora(components=components, **kwargs)
2216
+
2217
+ # We override this here to account for `_transformer_norm_layers` and `_overwritten_params`.
2218
+ def unload_lora_weights(self, reset_to_overwritten_params=False):
2219
+ """
2220
+ Unloads the LoRA parameters.
2221
+
2222
+ Args:
2223
+ reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules
2224
+ to their original params. Refer to the [Flux
2225
+ documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.
2226
+
2227
+ Examples:
2279
2228
 
2280
- # We override this here account for `_transformer_norm_layers`.
2281
- def unload_lora_weights(self):
2229
+ ```python
2230
+ >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
2231
+ >>> pipeline.unload_lora_weights()
2232
+ >>> ...
2233
+ ```
2234
+ """
2282
2235
  super().unload_lora_weights()
2283
2236
 
2284
2237
  transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
@@ -2286,11 +2239,55 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2286
2239
  transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
2287
2240
  transformer._transformer_norm_layers = None
2288
2241
 
2289
- @classmethod
2290
- def _maybe_expand_transformer_param_shape_or_error_(
2291
- cls,
2292
- transformer: torch.nn.Module,
2293
- lora_state_dict=None,
2242
+ if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
2243
+ overwritten_params = transformer._overwritten_params
2244
+ module_names = set()
2245
+
2246
+ for param_name in overwritten_params:
2247
+ if param_name.endswith(".weight"):
2248
+ module_names.add(param_name.replace(".weight", ""))
2249
+
2250
+ for name, module in transformer.named_modules():
2251
+ if isinstance(module, torch.nn.Linear) and name in module_names:
2252
+ module_weight = module.weight.data
2253
+ module_bias = module.bias.data if module.bias is not None else None
2254
+ bias = module_bias is not None
2255
+
2256
+ parent_module_name, _, current_module_name = name.rpartition(".")
2257
+ parent_module = transformer.get_submodule(parent_module_name)
2258
+
2259
+ current_param_weight = overwritten_params[f"{name}.weight"]
2260
+ in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
2261
+ with torch.device("meta"):
2262
+ original_module = torch.nn.Linear(
2263
+ in_features,
2264
+ out_features,
2265
+ bias=bias,
2266
+ dtype=module_weight.dtype,
2267
+ )
2268
+
2269
+ tmp_state_dict = {"weight": current_param_weight}
2270
+ if module_bias is not None:
2271
+ tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
2272
+ original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
2273
+ setattr(parent_module, current_module_name, original_module)
2274
+
2275
+ del tmp_state_dict
2276
+
2277
+ if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
2278
+ attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
2279
+ new_value = int(current_param_weight.shape[1])
2280
+ old_value = getattr(transformer.config, attribute_name)
2281
+ setattr(transformer.config, attribute_name, new_value)
2282
+ logger.info(
2283
+ f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
2284
+ )
2285
+
2286
+ @classmethod
2287
+ def _maybe_expand_transformer_param_shape_or_error_(
2288
+ cls,
2289
+ transformer: torch.nn.Module,
2290
+ lora_state_dict=None,
2294
2291
  norm_state_dict=None,
2295
2292
  prefix=None,
2296
2293
  ) -> bool:
@@ -2312,7 +2309,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2312
2309
 
2313
2310
  # Expand transformer parameter shapes if they don't match lora
2314
2311
  has_param_with_shape_update = False
2312
+ overwritten_params = {}
2313
+
2315
2314
  is_peft_loaded = getattr(transformer, "peft_config", None) is not None
2315
+ is_quantized = hasattr(transformer, "hf_quantizer")
2316
2316
  for name, module in transformer.named_modules():
2317
2317
  if isinstance(module, torch.nn.Linear):
2318
2318
  module_weight = module.weight.data
@@ -2328,11 +2328,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2328
2328
  in_features = state_dict[lora_A_weight_name].shape[1]
2329
2329
  out_features = state_dict[lora_B_weight_name].shape[0]
2330
2330
 
2331
+ # The model may have been loaded with different quantization schemes, which can flatten the params.
2332
+ # `bitsandbytes`, for example, flattens the weights when using 4-bit quantization, whereas 8-bit bnb models
2333
+ # preserve the weight shape.
2334
+ module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
2335
+
2331
2336
  # This means there's no need for an expansion in the params, so we simply skip.
2332
- if tuple(module_weight.shape) == (out_features, in_features):
2337
+ if tuple(module_weight_shape) == (out_features, in_features):
2333
2338
  continue
2334
2339
 
2335
- module_out_features, module_in_features = module_weight.shape
2340
+ module_out_features, module_in_features = module_weight_shape
2336
2341
  debug_message = ""
2337
2342
  if in_features > module_in_features:
2338
2343
  debug_message += (
@@ -2355,6 +2360,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2355
2360
  parent_module_name, _, current_module_name = name.rpartition(".")
2356
2361
  parent_module = transformer.get_submodule(parent_module_name)
2357
2362
 
2363
+ if is_quantized:
2364
+ module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
2365
+
2366
+ # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
2358
2367
  with torch.device("meta"):
2359
2368
  expanded_module = torch.nn.Linear(
2360
2369
  in_features, out_features, bias=bias, dtype=module_weight.dtype
@@ -2366,7 +2375,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2366
2375
  new_weight = torch.zeros_like(
2367
2376
  expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
2368
2377
  )
2369
- slices = tuple(slice(0, dim) for dim in module_weight.shape)
2378
+ slices = tuple(slice(0, dim) for dim in module_weight_shape)
2370
2379
  new_weight[slices] = module_weight
2371
2380
  tmp_state_dict = {"weight": new_weight}
2372
2381
  if module_bias is not None:
@@ -2386,6 +2395,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2386
2395
  f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
2387
2396
  )
2388
2397
 
2398
+ # For `unload_lora_weights()`.
2399
+ # TODO: this could lead to more memory overhead if the number of overwritten params
2400
+ # are large. Should be revisited later and tackled through a `discard_original_layers` arg.
2401
+ overwritten_params[f"{current_module_name}.weight"] = module_weight
2402
+ if module_bias is not None:
2403
+ overwritten_params[f"{current_module_name}.bias"] = module_bias
2404
+
2405
+ if len(overwritten_params) > 0:
2406
+ transformer._overwritten_params = overwritten_params
2407
+
2389
2408
  return has_param_with_shape_update
2390
2409
 
2391
2410
  @classmethod
@@ -2410,18 +2429,23 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2410
2429
  continue
2411
2430
 
2412
2431
  base_param_name = (
2413
- f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
2432
+ f"{k.replace(prefix, '')}.base_layer.weight"
2433
+ if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
2434
+ else f"{k.replace(prefix, '')}.weight"
2414
2435
  )
2415
2436
  base_weight_param = transformer_state_dict[base_param_name]
2416
2437
  lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
2417
2438
 
2418
- if base_weight_param.shape[1] > lora_A_param.shape[1]:
2439
+ # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
2440
+ base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
2441
+
2442
+ if base_module_shape[1] > lora_A_param.shape[1]:
2419
2443
  shape = (lora_A_param.shape[0], base_weight_param.shape[1])
2420
2444
  expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
2421
2445
  expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
2422
2446
  lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
2423
2447
  expanded_module_names.add(k)
2424
- elif base_weight_param.shape[1] < lora_A_param.shape[1]:
2448
+ elif base_module_shape[1] < lora_A_param.shape[1]:
2425
2449
  raise NotImplementedError(
2426
2450
  f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
2427
2451
  )
@@ -2433,6 +2457,33 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2433
2457
 
2434
2458
  return lora_state_dict
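The padding branch above widens `lora_A` with zero-filled columns so it matches an expanded base layer. A self-contained illustration of why the extra columns are harmless:

```py
# A rank-4 LoRA trained against a 64-wide input is padded with zeros to match a base layer
# that was expanded to 128 inputs.
import torch

rank, lora_in, base_in = 4, 64, 128
lora_A = torch.randn(rank, lora_in)

expanded = torch.zeros(rank, base_in)
expanded[:, :lora_in].copy_(lora_A)

# The padded columns contribute nothing, so the LoRA behaves identically on the original slice.
x = torch.randn(2, base_in)
assert torch.allclose(expanded @ x.T, lora_A @ x[:, :lora_in].T)
```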
2435
2459
 
2460
+ @staticmethod
2461
+ def _calculate_module_shape(
2462
+ model: "torch.nn.Module",
2463
+ base_module: "torch.nn.Linear" = None,
2464
+ base_weight_param_name: str = None,
2465
+ ) -> "torch.Size":
2466
+ def _get_weight_shape(weight: torch.Tensor):
2467
+ if weight.__class__.__name__ == "Params4bit":
2468
+ return weight.quant_state.shape
2469
+ elif weight.__class__.__name__ == "GGUFParameter":
2470
+ return weight.quant_shape
2471
+ else:
2472
+ return weight.shape
2473
+
2474
+ if base_module is not None:
2475
+ return _get_weight_shape(base_module.weight)
2476
+ elif base_weight_param_name is not None:
2477
+ if not base_weight_param_name.endswith(".weight"):
2478
+ raise ValueError(
2479
+ f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
2480
+ )
2481
+ module_path = base_weight_param_name.rsplit(".weight", 1)[0]
2482
+ submodule = get_submodule_by_name(model, module_path)
2483
+ return _get_weight_shape(submodule.weight)
2484
+
2485
+ raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
2486
+
2436
2487
 
2437
2488
  # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
2438
2489
  # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
@@ -2444,7 +2495,14 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2444
2495
  @classmethod
2445
2496
  # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
2446
2497
  def load_lora_into_transformer(
2447
- cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
2498
+ cls,
2499
+ state_dict,
2500
+ network_alphas,
2501
+ transformer,
2502
+ adapter_name=None,
2503
+ _pipeline=None,
2504
+ low_cpu_mem_usage=False,
2505
+ hotswap: bool = False,
2448
2506
  ):
2449
2507
  """
2450
2508
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -2466,6 +2524,29 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2466
2524
  low_cpu_mem_usage (`bool`, *optional*):
2467
2525
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2468
2526
  weights.
2527
+ hotswap : (`bool`, *optional*)
2528
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
2529
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
2530
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
2531
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
2532
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
2533
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
2534
+
2535
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
2536
+ to call an additional method before loading the adapter:
2537
+
2538
+ ```py
2539
+ pipeline = ... # load diffusers pipeline
2540
+ max_rank = ... # the highest rank among all LoRAs that you want to load
2541
+ # call *before* compiling and loading the LoRA adapter
2542
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
2543
+ pipeline.load_lora_weights(file_name)
2544
+ # optionally compile the model now
2545
+ ```
2546
+
2547
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
2548
+ limitations to this technique, which are documented here:
2549
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
2469
2550
  """
2470
2551
  if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
2471
2552
  raise ValueError(
@@ -2473,17 +2554,15 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2473
2554
  )
2474
2555
 
2475
2556
  # Load the layers corresponding to transformer.
2476
- keys = list(state_dict.keys())
2477
- transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
2478
- if transformer_present:
2479
- logger.info(f"Loading {cls.transformer_name}.")
2480
- transformer.load_lora_adapter(
2481
- state_dict,
2482
- network_alphas=network_alphas,
2483
- adapter_name=adapter_name,
2484
- _pipeline=_pipeline,
2485
- low_cpu_mem_usage=low_cpu_mem_usage,
2486
- )
2557
+ logger.info(f"Loading {cls.transformer_name}.")
2558
+ transformer.load_lora_adapter(
2559
+ state_dict,
2560
+ network_alphas=network_alphas,
2561
+ adapter_name=adapter_name,
2562
+ _pipeline=_pipeline,
2563
+ low_cpu_mem_usage=low_cpu_mem_usage,
2564
+ hotswap=hotswap,
2565
+ )
2487
2566
 
2488
2567
  @classmethod
2489
2568
  # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2497,6 +2576,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2497
2576
  adapter_name=None,
2498
2577
  _pipeline=None,
2499
2578
  low_cpu_mem_usage=False,
2579
+ hotswap: bool = False,
2500
2580
  ):
2501
2581
  """
2502
2582
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -2522,120 +2602,42 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2522
2602
  low_cpu_mem_usage (`bool`, *optional*):
2523
2603
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2524
2604
  weights.
2605
+ hotswap : (`bool`, *optional*)
2606
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
2607
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
2608
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
2609
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
2610
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
2611
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
2612
+
2613
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
2614
+ to call an additional method before loading the adapter:
2615
+
2616
+ ```py
2617
+ pipeline = ... # load diffusers pipeline
2618
+ max_rank = ... # the highest rank among all LoRAs that you want to load
2619
+ # call *before* compiling and loading the LoRA adapter
2620
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
2621
+ pipeline.load_lora_weights(file_name)
2622
+ # optionally compile the model now
2623
+ ```
2624
+
2625
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
2626
+ limitations to this technique, which are documented here:
2627
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
2525
2628
  """
2526
- if not USE_PEFT_BACKEND:
2527
- raise ValueError("PEFT backend is required for this method.")
2528
-
2529
- peft_kwargs = {}
2530
- if low_cpu_mem_usage:
2531
- if not is_peft_version(">=", "0.13.1"):
2532
- raise ValueError(
2533
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
2534
- )
2535
- if not is_transformers_version(">", "4.45.2"):
2536
- # Note from sayakpaul: It's not in `transformers` stable yet.
2537
- # https://github.com/huggingface/transformers/pull/33725/
2538
- raise ValueError(
2539
- "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
2540
- )
2541
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
2542
-
2543
- from peft import LoraConfig
2544
-
2545
- # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
2546
- # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
2547
- # their prefixes.
2548
- keys = list(state_dict.keys())
2549
- prefix = cls.text_encoder_name if prefix is None else prefix
2550
-
2551
- # Safe prefix to check with.
2552
- if any(cls.text_encoder_name in key for key in keys):
2553
- # Load the layers corresponding to text encoder and make necessary adjustments.
2554
- text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
2555
- text_encoder_lora_state_dict = {
2556
- k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
2557
- }
2558
-
2559
- if len(text_encoder_lora_state_dict) > 0:
2560
- logger.info(f"Loading {prefix}.")
2561
- rank = {}
2562
- text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
2563
-
2564
- # convert state dict
2565
- text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
2566
-
2567
- for name, _ in text_encoder_attn_modules(text_encoder):
2568
- for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
2569
- rank_key = f"{name}.{module}.lora_B.weight"
2570
- if rank_key not in text_encoder_lora_state_dict:
2571
- continue
2572
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
2573
-
2574
- for name, _ in text_encoder_mlp_modules(text_encoder):
2575
- for module in ("fc1", "fc2"):
2576
- rank_key = f"{name}.{module}.lora_B.weight"
2577
- if rank_key not in text_encoder_lora_state_dict:
2578
- continue
2579
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
2580
-
2581
- if network_alphas is not None:
2582
- alpha_keys = [
2583
- k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
2584
- ]
2585
- network_alphas = {
2586
- k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
2587
- }
2588
-
2589
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
2590
-
2591
- if "use_dora" in lora_config_kwargs:
2592
- if lora_config_kwargs["use_dora"]:
2593
- if is_peft_version("<", "0.9.0"):
2594
- raise ValueError(
2595
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
2596
- )
2597
- else:
2598
- if is_peft_version("<", "0.9.0"):
2599
- lora_config_kwargs.pop("use_dora")
2600
-
2601
- if "lora_bias" in lora_config_kwargs:
2602
- if lora_config_kwargs["lora_bias"]:
2603
- if is_peft_version("<=", "0.13.2"):
2604
- raise ValueError(
2605
- "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
2606
- )
2607
- else:
2608
- if is_peft_version("<=", "0.13.2"):
2609
- lora_config_kwargs.pop("lora_bias")
2610
-
2611
- lora_config = LoraConfig(**lora_config_kwargs)
2612
-
2613
- # adapter_name
2614
- if adapter_name is None:
2615
- adapter_name = get_adapter_name(text_encoder)
2616
-
2617
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
2618
-
2619
- # inject LoRA layers and load the state dict
2620
- # in transformers we automatically check whether the adapter name is already in use or not
2621
- text_encoder.load_adapter(
2622
- adapter_name=adapter_name,
2623
- adapter_state_dict=text_encoder_lora_state_dict,
2624
- peft_config=lora_config,
2625
- **peft_kwargs,
2626
- )
2627
-
2628
- # scale LoRA layers with `lora_scale`
2629
- scale_lora_layers(text_encoder, weight=lora_scale)
2630
-
2631
- text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
2632
-
2633
- # Offload back.
2634
- if is_model_cpu_offload:
2635
- _pipeline.enable_model_cpu_offload()
2636
- elif is_sequential_cpu_offload:
2637
- _pipeline.enable_sequential_cpu_offload()
2638
- # Unsafe code />
2629
+ _load_lora_into_text_encoder(
2630
+ state_dict=state_dict,
2631
+ network_alphas=network_alphas,
2632
+ lora_scale=lora_scale,
2633
+ text_encoder=text_encoder,
2634
+ prefix=prefix,
2635
+ text_encoder_name=cls.text_encoder_name,
2636
+ adapter_name=adapter_name,
2637
+ _pipeline=_pipeline,
2638
+ low_cpu_mem_usage=low_cpu_mem_usage,
2639
+ hotswap=hotswap,
2640
+ )
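
A short usage sketch of the hotswap preparation flow that the new docstrings describe. The model id, LoRA file name, and the rank value below are placeholders, not real checkpoints; `enable_lora_hotswap` and `load_lora_weights` are the entry points the docstring itself names.

```py
# Sketch of the hotswap preparation flow documented above.
# Model id, file name, and target_rank are assumed placeholders.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "some/compatible-pipeline",  # hypothetical model id
    torch_dtype=torch.float16,
).to("cuda")

# Reserve room for the largest LoRA that will ever be swapped in,
# *before* loading the first adapter or compiling the transformer.
pipeline.enable_lora_hotswap(target_rank=64)  # 64 is an assumed upper bound

pipeline.load_lora_weights("first_lora.safetensors")  # placeholder file name

# Compiling after the first adapter is loaded means later hotswaps of
# same-or-lower-rank adapters do not trigger recompilation.
pipeline.transformer = torch.compile(pipeline.transformer)
```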
2639
2641
 
2640
2642
  @classmethod
2641
2643
  def save_lora_weights(
@@ -2851,7 +2853,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2851
2853
  @classmethod
2852
2854
  # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
2853
2855
  def load_lora_into_transformer(
2854
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
2856
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
2855
2857
  ):
2856
2858
  """
2857
2859
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -2869,6 +2871,29 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2869
2871
  low_cpu_mem_usage (`bool`, *optional*):
2870
2872
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2871
2873
  weights.
2874
+ hotswap : (`bool`, *optional*)
2875
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
2876
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
2877
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
2878
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
2879
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
2880
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
2881
+
2882
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
2883
+ to call an additional method before loading the adapter:
2884
+
2885
+ ```py
2886
+ pipeline = ... # load diffusers pipeline
2887
+ max_rank = ... # the highest rank among all LoRAs that you want to load
2888
+ # call *before* compiling and loading the LoRA adapter
2889
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
2890
+ pipeline.load_lora_weights(file_name)
2891
+ # optionally compile the model now
2892
+ ```
2893
+
2894
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
2895
+ limitations to this technique, which are documented here:
2896
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
2872
2897
  """
2873
2898
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
2874
2899
  raise ValueError(
@@ -2883,6 +2908,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2883
2908
  adapter_name=adapter_name,
2884
2909
  _pipeline=_pipeline,
2885
2910
  low_cpu_mem_usage=low_cpu_mem_usage,
2911
+ hotswap=hotswap,
2886
2912
  )
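
The hunk above threads the new `hotswap` flag down to `load_lora_adapter`. Below is a hedged sketch of exercising that flag directly through the classmethod shown here, replacing an adapter that is already loaded; the file and adapter names are placeholders, and the pipeline is assumed to have been prepared as in the earlier hotswap sketch.

```py
# Sketch: swapping an already-loaded adapter in place via load_lora_into_transformer.
# File and adapter names are placeholders.
pipe = ...  # a CogVideoX-style pipeline with adapter "default_0" already loaded (optionally compiled)
state_dict = pipe.lora_state_dict("second_lora.safetensors")
pipe.load_lora_into_transformer(
    state_dict,
    transformer=pipe.transformer,
    adapter_name="default_0",  # must name an adapter that is already loaded
    hotswap=True,
)
```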
2887
2913
 
2888
2914
  @classmethod
@@ -2933,10 +2959,9 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2933
2959
  safe_serialization=safe_serialization,
2934
2960
  )
2935
2961
 
2936
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
2937
2962
  def fuse_lora(
2938
2963
  self,
2939
- components: List[str] = ["transformer", "text_encoder"],
2964
+ components: List[str] = ["transformer"],
2940
2965
  lora_scale: float = 1.0,
2941
2966
  safe_fusing: bool = False,
2942
2967
  adapter_names: Optional[List[str]] = None,
@@ -2974,11 +2999,14 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2974
2999
  ```
2975
3000
  """
2976
3001
  super().fuse_lora(
2977
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
3002
+ components=components,
3003
+ lora_scale=lora_scale,
3004
+ safe_fusing=safe_fusing,
3005
+ adapter_names=adapter_names,
3006
+ **kwargs,
2978
3007
  )
2979
3008
 
2980
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
2981
- def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
3009
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
2982
3010
  r"""
2983
3011
  Reverses the effect of
2984
3012
  [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -2992,11 +3020,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2992
3020
  Args:
2993
3021
  components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
2994
3022
  unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
2995
- unfuse_text_encoder (`bool`, defaults to `True`):
2996
- Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
2997
- LoRA parameters then it won't have any effect.
2998
3023
  """
2999
- super().unfuse_lora(components=components)
3024
+ super().unfuse_lora(components=components, **kwargs)
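
With the `fuse_lora`/`unfuse_lora` changes above, the default `components` list for these transformer-based mixins now contains only `"transformer"`, since they do not load LoRA into a text encoder. A minimal sketch, with placeholder model and LoRA ids:

```py
# Minimal fuse/unfuse sketch for a transformer-only LoRA mixin.
# The model id and LoRA repo below are placeholders.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some/cogvideox-like-model", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("some/lora-repo", weight_name="lora.safetensors", adapter_name="style")

# Fold the LoRA into the transformer weights so inference runs without adapter overhead.
pipe.fuse_lora(components=["transformer"], lora_scale=0.8, adapter_names=["style"])

# ... run inference ...

# Restore the original transformer weights.
pipe.unfuse_lora(components=["transformer"])
```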
3000
3025
 
3001
3026
 
3002
3027
  class Mochi1LoraLoaderMixin(LoraBaseMixin):
@@ -3159,7 +3184,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3159
3184
  @classmethod
3160
3185
  # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
3161
3186
  def load_lora_into_transformer(
3162
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
3187
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
3163
3188
  ):
3164
3189
  """
3165
3190
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -3177,6 +3202,29 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3177
3202
  low_cpu_mem_usage (`bool`, *optional*):
3178
3203
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3179
3204
  weights.
3205
+ hotswap : (`bool`, *optional*)
3206
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
3207
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
3208
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
3209
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
3210
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
3211
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
3212
+
3213
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
3214
+ to call an additional method before loading the adapter:
3215
+
3216
+ ```py
3217
+ pipeline = ... # load diffusers pipeline
3218
+ max_rank = ... # the highest rank among all LoRAs that you want to load
3219
+ # call *before* compiling and loading the LoRA adapter
3220
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
3221
+ pipeline.load_lora_weights(file_name)
3222
+ # optionally compile the model now
3223
+ ```
3224
+
3225
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
3226
+ limitations to this technique, which are documented here:
3227
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
3180
3228
  """
3181
3229
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3182
3230
  raise ValueError(
@@ -3191,6 +3239,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3191
3239
  adapter_name=adapter_name,
3192
3240
  _pipeline=_pipeline,
3193
3241
  low_cpu_mem_usage=low_cpu_mem_usage,
3242
+ hotswap=hotswap,
3194
3243
  )
3195
3244
 
3196
3245
  @classmethod
@@ -3241,10 +3290,10 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3241
3290
  safe_serialization=safe_serialization,
3242
3291
  )
3243
3292
 
3244
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
3293
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
3245
3294
  def fuse_lora(
3246
3295
  self,
3247
- components: List[str] = ["transformer", "text_encoder"],
3296
+ components: List[str] = ["transformer"],
3248
3297
  lora_scale: float = 1.0,
3249
3298
  safe_fusing: bool = False,
3250
3299
  adapter_names: Optional[List[str]] = None,
@@ -3282,11 +3331,15 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3282
3331
  ```
3283
3332
  """
3284
3333
  super().fuse_lora(
3285
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
3334
+ components=components,
3335
+ lora_scale=lora_scale,
3336
+ safe_fusing=safe_fusing,
3337
+ adapter_names=adapter_names,
3338
+ **kwargs,
3286
3339
  )
3287
3340
 
3288
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
3289
- def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
3341
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
3342
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
3290
3343
  r"""
3291
3344
  Reverses the effect of
3292
3345
  [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -3300,11 +3353,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3300
3353
  Args:
3301
3354
  components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
3302
3355
  unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
3303
- unfuse_text_encoder (`bool`, defaults to `True`):
3304
- Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
3305
- LoRA parameters then it won't have any effect.
3306
3356
  """
3307
- super().unfuse_lora(components=components)
3357
+ super().unfuse_lora(components=components, **kwargs)
3308
3358
 
3309
3359
 
3310
3360
  class LTXVideoLoraLoaderMixin(LoraBaseMixin):
@@ -3467,7 +3517,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3467
3517
  @classmethod
3468
3518
  # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
3469
3519
  def load_lora_into_transformer(
3470
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
3520
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
3471
3521
  ):
3472
3522
  """
3473
3523
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -3485,6 +3535,29 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3485
3535
  low_cpu_mem_usage (`bool`, *optional*):
3486
3536
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3487
3537
  weights.
3538
+ hotswap : (`bool`, *optional*)
3539
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
3540
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
3541
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
3542
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
3543
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
3544
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
3545
+
3546
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
3547
+ to call an additional method before loading the adapter:
3548
+
3549
+ ```py
3550
+ pipeline = ... # load diffusers pipeline
3551
+ max_rank = ... # the highest rank among all LoRAs that you want to load
3552
+ # call *before* compiling and loading the LoRA adapter
3553
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
3554
+ pipeline.load_lora_weights(file_name)
3555
+ # optionally compile the model now
3556
+ ```
3557
+
3558
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
3559
+ limitations to this technique, which are documented here:
3560
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
3488
3561
  """
3489
3562
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3490
3563
  raise ValueError(
@@ -3499,6 +3572,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3499
3572
  adapter_name=adapter_name,
3500
3573
  _pipeline=_pipeline,
3501
3574
  low_cpu_mem_usage=low_cpu_mem_usage,
3575
+ hotswap=hotswap,
3502
3576
  )
3503
3577
 
3504
3578
  @classmethod
@@ -3549,10 +3623,10 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3549
3623
  safe_serialization=safe_serialization,
3550
3624
  )
3551
3625
 
3552
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
3626
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
3553
3627
  def fuse_lora(
3554
3628
  self,
3555
- components: List[str] = ["transformer", "text_encoder"],
3629
+ components: List[str] = ["transformer"],
3556
3630
  lora_scale: float = 1.0,
3557
3631
  safe_fusing: bool = False,
3558
3632
  adapter_names: Optional[List[str]] = None,
@@ -3590,11 +3664,15 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3590
3664
  ```
3591
3665
  """
3592
3666
  super().fuse_lora(
3593
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
3667
+ components=components,
3668
+ lora_scale=lora_scale,
3669
+ safe_fusing=safe_fusing,
3670
+ adapter_names=adapter_names,
3671
+ **kwargs,
3594
3672
  )
3595
3673
 
3596
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
3597
- def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
3674
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
3675
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
3598
3676
  r"""
3599
3677
  Reverses the effect of
3600
3678
  [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -3608,11 +3686,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3608
3686
  Args:
3609
3687
  components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
3610
3688
  unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
3611
- unfuse_text_encoder (`bool`, defaults to `True`):
3612
- Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
3613
- LoRA parameters then it won't have any effect.
3614
3689
  """
3615
- super().unfuse_lora(components=components)
3690
+ super().unfuse_lora(components=components, **kwargs)
3616
3691
 
3617
3692
 
3618
3693
  class SanaLoraLoaderMixin(LoraBaseMixin):
@@ -3775,7 +3850,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3775
3850
  @classmethod
3776
3851
  # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
3777
3852
  def load_lora_into_transformer(
3778
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
3853
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
3779
3854
  ):
3780
3855
  """
3781
3856
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -3793,6 +3868,29 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3793
3868
  low_cpu_mem_usage (`bool`, *optional*):
3794
3869
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3795
3870
  weights.
3871
+ hotswap : (`bool`, *optional*)
3872
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
3873
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
3874
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
3875
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
3876
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
3877
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
3878
+
3879
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
3880
+ to call an additional method before loading the adapter:
3881
+
3882
+ ```py
3883
+ pipeline = ... # load diffusers pipeline
3884
+ max_rank = ... # the highest rank among all LoRAs that you want to load
3885
+ # call *before* compiling and loading the LoRA adapter
3886
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
3887
+ pipeline.load_lora_weights(file_name)
3888
+ # optionally compile the model now
3889
+ ```
3890
+
3891
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
3892
+ limitations to this technique, which are documented here:
3893
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
3796
3894
  """
3797
3895
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3798
3896
  raise ValueError(
@@ -3807,6 +3905,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3807
3905
  adapter_name=adapter_name,
3808
3906
  _pipeline=_pipeline,
3809
3907
  low_cpu_mem_usage=low_cpu_mem_usage,
3908
+ hotswap=hotswap,
3810
3909
  )
3811
3910
 
3812
3911
  @classmethod
@@ -3857,10 +3956,10 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3857
3956
  safe_serialization=safe_serialization,
3858
3957
  )
3859
3958
 
3860
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
3959
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
3861
3960
  def fuse_lora(
3862
3961
  self,
3863
- components: List[str] = ["transformer", "text_encoder"],
3962
+ components: List[str] = ["transformer"],
3864
3963
  lora_scale: float = 1.0,
3865
3964
  safe_fusing: bool = False,
3866
3965
  adapter_names: Optional[List[str]] = None,
@@ -3898,11 +3997,15 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3898
3997
  ```
3899
3998
  """
3900
3999
  super().fuse_lora(
3901
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
4000
+ components=components,
4001
+ lora_scale=lora_scale,
4002
+ safe_fusing=safe_fusing,
4003
+ adapter_names=adapter_names,
4004
+ **kwargs,
3902
4005
  )
3903
4006
 
3904
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
3905
- def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
4007
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
4008
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
3906
4009
  r"""
3907
4010
  Reverses the effect of
3908
4011
  [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -3916,11 +4019,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3916
4019
  Args:
3917
4020
  components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
3918
4021
  unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
3919
- unfuse_text_encoder (`bool`, defaults to `True`):
3920
- Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
3921
- LoRA parameters then it won't have any effect.
3922
4022
  """
3923
- super().unfuse_lora(components=components)
4023
+ super().unfuse_lora(components=components, **kwargs)
3924
4024
 
3925
4025
 
3926
4026
  class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
@@ -3933,7 +4033,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
3933
4033
 
3934
4034
  @classmethod
3935
4035
  @validate_hf_hub_args
3936
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
3937
4036
  def lora_state_dict(
3938
4037
  cls,
3939
4038
  pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -3944,7 +4043,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
3944
4043
 
3945
4044
  <Tip warning={true}>
3946
4045
 
3947
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
4046
+ We support loading original format HunyuanVideo LoRA checkpoints.
3948
4047
 
3949
4048
  This function is experimental and might change in the future.
3950
4049
 
@@ -4027,6 +4126,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4027
4126
  logger.warning(warn_msg)
4028
4127
  state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
4029
4128
 
4129
+ is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
4130
+ if is_original_hunyuan_video:
4131
+ state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
4132
+
4030
4133
  return state_dict
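
The new branch above converts original-format HunyuanVideo LoRA checkpoints by looking for `img_attn_qkv` keys. The converter itself (`_convert_hunyuan_video_lora_to_diffusers`) is internal, but the detection rule is simple enough to illustrate; the sample key names below are fabricated for demonstration only.

```py
# Illustration of the detection rule used above: an "original format"
# HunyuanVideo LoRA is recognised by its `img_attn_qkv` keys.
original_style = {
    "double_blocks.0.img_attn_qkv.lora_A.weight": ...,  # made-up key, placeholder value
}
diffusers_style = {
    "transformer.transformer_blocks.0.attn.to_q.lora_A.weight": ...,  # made-up key
}

def needs_conversion(state_dict):
    return any("img_attn_qkv" in k for k in state_dict)

assert needs_conversion(original_style)
assert not needs_conversion(diffusers_style)
```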
4031
4134
 
4032
4135
  # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
@@ -4083,7 +4186,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4083
4186
  @classmethod
4084
4187
  # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
4085
4188
  def load_lora_into_transformer(
4086
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
4189
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
4087
4190
  ):
4088
4191
  """
4089
4192
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -4101,6 +4204,29 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4101
4204
  low_cpu_mem_usage (`bool`, *optional*):
4102
4205
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4103
4206
  weights.
4207
+ hotswap : (`bool`, *optional*)
4208
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
4209
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
4210
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
4211
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
4212
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
4213
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
4214
+
4215
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
4216
+ to call an additional method before loading the adapter:
4217
+
4218
+ ```py
4219
+ pipeline = ... # load diffusers pipeline
4220
+ max_rank = ... # the highest rank among all LoRAs that you want to load
4221
+ # call *before* compiling and loading the LoRA adapter
4222
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
4223
+ pipeline.load_lora_weights(file_name)
4224
+ # optionally compile the model now
4225
+ ```
4226
+
4227
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
4228
+ limitations to this technique, which are documented here:
4229
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
4104
4230
  """
4105
4231
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4106
4232
  raise ValueError(
@@ -4115,6 +4241,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4115
4241
  adapter_name=adapter_name,
4116
4242
  _pipeline=_pipeline,
4117
4243
  low_cpu_mem_usage=low_cpu_mem_usage,
4244
+ hotswap=hotswap,
4118
4245
  )
4119
4246
 
4120
4247
  @classmethod
@@ -4165,10 +4292,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4165
4292
  safe_serialization=safe_serialization,
4166
4293
  )
4167
4294
 
4168
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
4295
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
4169
4296
  def fuse_lora(
4170
4297
  self,
4171
- components: List[str] = ["transformer", "text_encoder"],
4298
+ components: List[str] = ["transformer"],
4172
4299
  lora_scale: float = 1.0,
4173
4300
  safe_fusing: bool = False,
4174
4301
  adapter_names: Optional[List[str]] = None,
@@ -4206,11 +4333,1049 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4206
4333
  ```
4207
4334
  """
4208
4335
  super().fuse_lora(
4209
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
4336
+ components=components,
4337
+ lora_scale=lora_scale,
4338
+ safe_fusing=safe_fusing,
4339
+ adapter_names=adapter_names,
4340
+ **kwargs,
4210
4341
  )
4211
4342
 
4212
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
4213
- def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
4343
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
4344
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
4345
+ r"""
4346
+ Reverses the effect of
4347
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
4348
+
4349
+ <Tip warning={true}>
4350
+
4351
+ This is an experimental API.
4352
+
4353
+ </Tip>
4354
+
4355
+ Args:
4356
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
4357
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
4358
+ """
4359
+ super().unfuse_lora(components=components, **kwargs)
4360
+
4361
+
4362
+ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4363
+ r"""
4364
+ Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`].
4365
+ """
4366
+
4367
+ _lora_loadable_modules = ["transformer"]
4368
+ transformer_name = TRANSFORMER_NAME
4369
+
4370
+ @classmethod
4371
+ @validate_hf_hub_args
4372
+ def lora_state_dict(
4373
+ cls,
4374
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
4375
+ **kwargs,
4376
+ ):
4377
+ r"""
4378
+ Return state dict for lora weights and the network alphas.
4379
+
4380
+ <Tip warning={true}>
4381
+
4382
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
4383
+
4384
+ This function is experimental and might change in the future.
4385
+
4386
+ </Tip>
4387
+
4388
+ Parameters:
4389
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
4390
+ Can be either:
4391
+
4392
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
4393
+ the Hub.
4394
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
4395
+ with [`ModelMixin.save_pretrained`].
4396
+ - A [torch state
4397
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
4398
+
4399
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
4400
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
4401
+ is not used.
4402
+ force_download (`bool`, *optional*, defaults to `False`):
4403
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
4404
+ cached versions if they exist.
4405
+
4406
+ proxies (`Dict[str, str]`, *optional*):
4407
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
4408
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
4409
+ local_files_only (`bool`, *optional*, defaults to `False`):
4410
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
4411
+ won't be downloaded from the Hub.
4412
+ token (`str` or *bool*, *optional*):
4413
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
4414
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
4415
+ revision (`str`, *optional*, defaults to `"main"`):
4416
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
4417
+ allowed by Git.
4418
+ subfolder (`str`, *optional*, defaults to `""`):
4419
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
4420
+
4421
+ """
4422
+ # Load the main state dict first which has the LoRA layers for either of
4423
+ # transformer and text encoder or both.
4424
+ cache_dir = kwargs.pop("cache_dir", None)
4425
+ force_download = kwargs.pop("force_download", False)
4426
+ proxies = kwargs.pop("proxies", None)
4427
+ local_files_only = kwargs.pop("local_files_only", None)
4428
+ token = kwargs.pop("token", None)
4429
+ revision = kwargs.pop("revision", None)
4430
+ subfolder = kwargs.pop("subfolder", None)
4431
+ weight_name = kwargs.pop("weight_name", None)
4432
+ use_safetensors = kwargs.pop("use_safetensors", None)
4433
+
4434
+ allow_pickle = False
4435
+ if use_safetensors is None:
4436
+ use_safetensors = True
4437
+ allow_pickle = True
4438
+
4439
+ user_agent = {
4440
+ "file_type": "attn_procs_weights",
4441
+ "framework": "pytorch",
4442
+ }
4443
+
4444
+ state_dict = _fetch_state_dict(
4445
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
4446
+ weight_name=weight_name,
4447
+ use_safetensors=use_safetensors,
4448
+ local_files_only=local_files_only,
4449
+ cache_dir=cache_dir,
4450
+ force_download=force_download,
4451
+ proxies=proxies,
4452
+ token=token,
4453
+ revision=revision,
4454
+ subfolder=subfolder,
4455
+ user_agent=user_agent,
4456
+ allow_pickle=allow_pickle,
4457
+ )
4458
+
4459
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
4460
+ if is_dora_scale_present:
4461
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
4462
+ logger.warning(warn_msg)
4463
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
4464
+
4465
+ # conversion.
4466
+ non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
4467
+ if non_diffusers:
4468
+ state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
4469
+
4470
+ return state_dict
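
Because the conversion happens inside `lora_state_dict`, callers always receive diffusers-style keys, whether or not the file on disk used the `diffusion_model.` prefix. A small inspection sketch, assuming the `Lumina2Text2ImgPipeline` class referenced in the docstring and a placeholder path:

```py
# Sketch: lora_state_dict returns converted, diffusers-style keys.
# The file path is a placeholder.
from diffusers import Lumina2Text2ImgPipeline

state_dict = Lumina2Text2ImgPipeline.lora_state_dict("path/to/lumina2_lora.safetensors")
assert not any(k.startswith("diffusion_model.") for k in state_dict)
print(sorted(state_dict)[:5])  # peek at the first few converted key names
```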
4471
+
4472
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
4473
+ def load_lora_weights(
4474
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
4475
+ ):
4476
+ """
4477
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
4478
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
4479
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
4480
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
4481
+ dict is loaded into `self.transformer`.
4482
+
4483
+ Parameters:
4484
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
4485
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4486
+ adapter_name (`str`, *optional*):
4487
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4488
+ `default_{i}` where i is the total number of adapters being loaded.
4489
+ low_cpu_mem_usage (`bool`, *optional*):
4490
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4491
+ weights.
4492
+ kwargs (`dict`, *optional*):
4493
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4494
+ """
4495
+ if not USE_PEFT_BACKEND:
4496
+ raise ValueError("PEFT backend is required for this method.")
4497
+
4498
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
4499
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4500
+ raise ValueError(
4501
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4502
+ )
4503
+
4504
+ # if a dict is passed, copy it instead of modifying it inplace
4505
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
4506
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
4507
+
4508
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
4509
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4510
+
4511
+ is_correct_format = all("lora" in key for key in state_dict.keys())
4512
+ if not is_correct_format:
4513
+ raise ValueError("Invalid LoRA checkpoint.")
4514
+
4515
+ self.load_lora_into_transformer(
4516
+ state_dict,
4517
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
4518
+ adapter_name=adapter_name,
4519
+ _pipeline=self,
4520
+ low_cpu_mem_usage=low_cpu_mem_usage,
4521
+ )
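
An end-to-end usage sketch for the new Lumina2 mixin; the model and LoRA ids are placeholders and the prompt is illustrative only.

```py
# Usage sketch for the new Lumina2 LoRA support; ids below are placeholders.
import torch
from diffusers import Lumina2Text2ImgPipeline

pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "some/lumina2-model", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("some/lumina2-lora", weight_name="lora.safetensors", adapter_name="style")
image = pipe("a watercolor fox").images[0]
```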
4522
+
4523
+ @classmethod
4524
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
4525
+ def load_lora_into_transformer(
4526
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
4527
+ ):
4528
+ """
4529
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
4530
+
4531
+ Parameters:
4532
+ state_dict (`dict`):
4533
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
4534
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
4535
+ encoder lora layers.
4536
+ transformer (`Lumina2Transformer2DModel`):
4537
+ The Transformer model to load the LoRA layers into.
4538
+ adapter_name (`str`, *optional*):
4539
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4540
+ `default_{i}` where i is the total number of adapters being loaded.
4541
+ low_cpu_mem_usage (`bool`, *optional*):
4542
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4543
+ weights.
4544
+ hotswap : (`bool`, *optional*)
4545
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
4546
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
4547
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
4548
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
4549
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
4550
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
4551
+
4552
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
4553
+ to call an additional method before loading the adapter:
4554
+
4555
+ ```py
4556
+ pipeline = ... # load diffusers pipeline
4557
+ max_rank = ... # the highest rank among all LoRAs that you want to load
4558
+ # call *before* compiling and loading the LoRA adapter
4559
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
4560
+ pipeline.load_lora_weights(file_name)
4561
+ # optionally compile the model now
4562
+ ```
4563
+
4564
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
4565
+ limitations to this technique, which are documented here:
4566
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
4567
+ """
4568
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4569
+ raise ValueError(
4570
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4571
+ )
4572
+
4573
+ # Load the layers corresponding to transformer.
4574
+ logger.info(f"Loading {cls.transformer_name}.")
4575
+ transformer.load_lora_adapter(
4576
+ state_dict,
4577
+ network_alphas=None,
4578
+ adapter_name=adapter_name,
4579
+ _pipeline=_pipeline,
4580
+ low_cpu_mem_usage=low_cpu_mem_usage,
4581
+ hotswap=hotswap,
4582
+ )
4583
+
4584
+ @classmethod
4585
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
4586
+ def save_lora_weights(
4587
+ cls,
4588
+ save_directory: Union[str, os.PathLike],
4589
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
4590
+ is_main_process: bool = True,
4591
+ weight_name: str = None,
4592
+ save_function: Callable = None,
4593
+ safe_serialization: bool = True,
4594
+ ):
4595
+ r"""
4596
+ Save the LoRA parameters corresponding to the UNet and text encoder.
4597
+
4598
+ Arguments:
4599
+ save_directory (`str` or `os.PathLike`):
4600
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
4601
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
4602
+ State dict of the LoRA layers corresponding to the `transformer`.
4603
+ is_main_process (`bool`, *optional*, defaults to `True`):
4604
+ Whether the process calling this is the main process or not. Useful during distributed training and you
4605
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
4606
+ process to avoid race conditions.
4607
+ save_function (`Callable`):
4608
+ The function to use to save the state dictionary. Useful during distributed training when you need to
4609
+ replace `torch.save` with another method. Can be configured with the environment variable
4610
+ `DIFFUSERS_SAVE_MODE`.
4611
+ safe_serialization (`bool`, *optional*, defaults to `True`):
4612
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
4613
+ """
4614
+ state_dict = {}
4615
+
4616
+ if not transformer_lora_layers:
4617
+ raise ValueError("You must pass `transformer_lora_layers`.")
4618
+
4619
+ if transformer_lora_layers:
4620
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
4621
+
4622
+ # Save the model
4623
+ cls.write_lora_layers(
4624
+ state_dict=state_dict,
4625
+ save_directory=save_directory,
4626
+ is_main_process=is_main_process,
4627
+ weight_name=weight_name,
4628
+ save_function=save_function,
4629
+ safe_serialization=safe_serialization,
4630
+ )
4631
+
4632
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
4633
+ def fuse_lora(
4634
+ self,
4635
+ components: List[str] = ["transformer"],
4636
+ lora_scale: float = 1.0,
4637
+ safe_fusing: bool = False,
4638
+ adapter_names: Optional[List[str]] = None,
4639
+ **kwargs,
4640
+ ):
4641
+ r"""
4642
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
4643
+
4644
+ <Tip warning={true}>
4645
+
4646
+ This is an experimental API.
4647
+
4648
+ </Tip>
4649
+
4650
+ Args:
4651
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
4652
+ lora_scale (`float`, defaults to 1.0):
4653
+ Controls how much to influence the outputs with the LoRA parameters.
4654
+ safe_fusing (`bool`, defaults to `False`):
4655
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
4656
+ adapter_names (`List[str]`, *optional*):
4657
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
4658
+
4659
+ Example:
4660
+
4661
+ ```py
4662
+ from diffusers import DiffusionPipeline
4663
+ import torch
4664
+
4665
+ pipeline = DiffusionPipeline.from_pretrained(
4666
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
4667
+ ).to("cuda")
4668
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
4669
+ pipeline.fuse_lora(lora_scale=0.7)
4670
+ ```
4671
+ """
4672
+ super().fuse_lora(
4673
+ components=components,
4674
+ lora_scale=lora_scale,
4675
+ safe_fusing=safe_fusing,
4676
+ adapter_names=adapter_names,
4677
+ **kwargs,
4678
+ )
4679
+
4680
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
4681
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
4682
+ r"""
4683
+ Reverses the effect of
4684
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
4685
+
4686
+ <Tip warning={true}>
4687
+
4688
+ This is an experimental API.
4689
+
4690
+ </Tip>
4691
+
4692
+ Args:
4693
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
4694
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
4695
+ """
4696
+ super().unfuse_lora(components=components, **kwargs)
4697
+
4698
+
4699
+ class WanLoraLoaderMixin(LoraBaseMixin):
4700
+ r"""
4701
+ Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
4702
+ """
4703
+
4704
+ _lora_loadable_modules = ["transformer"]
4705
+ transformer_name = TRANSFORMER_NAME
4706
+
4707
+ @classmethod
4708
+ @validate_hf_hub_args
4709
+ def lora_state_dict(
4710
+ cls,
4711
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
4712
+ **kwargs,
4713
+ ):
4714
+ r"""
4715
+ Return state dict for lora weights and the network alphas.
4716
+
4717
+ <Tip warning={true}>
4718
+
4719
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
4720
+
4721
+ This function is experimental and might change in the future.
4722
+
4723
+ </Tip>
4724
+
4725
+ Parameters:
4726
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
4727
+ Can be either:
4728
+
4729
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
4730
+ the Hub.
4731
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
4732
+ with [`ModelMixin.save_pretrained`].
4733
+ - A [torch state
4734
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
4735
+
4736
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
4737
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
4738
+ is not used.
4739
+ force_download (`bool`, *optional*, defaults to `False`):
4740
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
4741
+ cached versions if they exist.
4742
+
4743
+ proxies (`Dict[str, str]`, *optional*):
4744
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
4745
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
4746
+ local_files_only (`bool`, *optional*, defaults to `False`):
4747
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
4748
+ won't be downloaded from the Hub.
4749
+ token (`str` or *bool*, *optional*):
4750
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
4751
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
4752
+ revision (`str`, *optional*, defaults to `"main"`):
4753
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
4754
+ allowed by Git.
4755
+ subfolder (`str`, *optional*, defaults to `""`):
4756
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
4757
+
4758
+ """
4759
+ # Load the main state dict first which has the LoRA layers for either of
4760
+ # transformer and text encoder or both.
4761
+ cache_dir = kwargs.pop("cache_dir", None)
4762
+ force_download = kwargs.pop("force_download", False)
4763
+ proxies = kwargs.pop("proxies", None)
4764
+ local_files_only = kwargs.pop("local_files_only", None)
4765
+ token = kwargs.pop("token", None)
4766
+ revision = kwargs.pop("revision", None)
4767
+ subfolder = kwargs.pop("subfolder", None)
4768
+ weight_name = kwargs.pop("weight_name", None)
4769
+ use_safetensors = kwargs.pop("use_safetensors", None)
4770
+
4771
+ allow_pickle = False
4772
+ if use_safetensors is None:
4773
+ use_safetensors = True
4774
+ allow_pickle = True
4775
+
4776
+ user_agent = {
4777
+ "file_type": "attn_procs_weights",
4778
+ "framework": "pytorch",
4779
+ }
4780
+
4781
+ state_dict = _fetch_state_dict(
4782
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
4783
+ weight_name=weight_name,
4784
+ use_safetensors=use_safetensors,
4785
+ local_files_only=local_files_only,
4786
+ cache_dir=cache_dir,
4787
+ force_download=force_download,
4788
+ proxies=proxies,
4789
+ token=token,
4790
+ revision=revision,
4791
+ subfolder=subfolder,
4792
+ user_agent=user_agent,
4793
+ allow_pickle=allow_pickle,
4794
+ )
4795
+ if any(k.startswith("diffusion_model.") for k in state_dict):
4796
+ state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
4797
+
4798
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
4799
+ if is_dora_scale_present:
4800
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
4801
+ logger.warning(warn_msg)
4802
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
4803
+
4804
+ return state_dict
4805
+
4806
+ @classmethod
4807
+ def _maybe_expand_t2v_lora_for_i2v(
4808
+ cls,
4809
+ transformer: torch.nn.Module,
4810
+ state_dict,
4811
+ ):
4812
+ if transformer.config.image_dim is None:
4813
+ return state_dict
4814
+
4815
+ if any(k.startswith("transformer.blocks.") for k in state_dict):
4816
+ num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
4817
+ is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
4818
+
4819
+ if is_i2v_lora:
4820
+ return state_dict
4821
+
4822
+ for i in range(num_blocks):
4823
+ for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
4824
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
4825
+ state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
4826
+ )
4827
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
4828
+ state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
4829
+ )
4830
+
4831
+ return state_dict
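
The helper above lets a text-to-video Wan LoRA load into an image-to-video transformer: the T2V checkpoint has no weights for the image cross-attention projections (`add_k_proj`, `add_v_proj`), so zero tensors with matching shapes are inserted and the image-conditioning path is left unchanged. A tiny torch-only illustration of that padding rule, with made-up shapes:

```py
# Tiny illustration of the zero-padding rule above. Shapes are made up.
import torch

t2v_lora = {
    "transformer.blocks.0.attn2.to_k.lora_A.weight": torch.randn(4, 1536),
    "transformer.blocks.0.attn2.to_k.lora_B.weight": torch.randn(1536, 4),
}

for new_name in ("add_k_proj", "add_v_proj"):
    for mat in ("lora_A", "lora_B"):
        template = t2v_lora[f"transformer.blocks.0.attn2.to_k.{mat}.weight"]
        # Zero matrices contribute nothing, so the image-attention layers behave
        # exactly as they would without a LoRA.
        t2v_lora[f"transformer.blocks.0.attn2.{new_name}.{mat}.weight"] = torch.zeros_like(template)

# The expanded dict now loads into the I2V transformer without missing-key errors.
```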
4832
+
4833
+ def load_lora_weights(
4834
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
4835
+ ):
4836
+ """
4837
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
4838
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
4839
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
4840
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
4841
+ dict is loaded into `self.transformer`.
4842
+
4843
+ Parameters:
4844
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
4845
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4846
+ adapter_name (`str`, *optional*):
4847
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4848
+ `default_{i}` where i is the total number of adapters being loaded.
4849
+ low_cpu_mem_usage (`bool`, *optional*):
4850
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4851
+ weights.
4852
+ kwargs (`dict`, *optional*):
4853
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4854
+ """
4855
+ if not USE_PEFT_BACKEND:
4856
+ raise ValueError("PEFT backend is required for this method.")
4857
+
4858
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
4859
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4860
+ raise ValueError(
4861
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4862
+ )
4863
+
4864
+ # if a dict is passed, copy it instead of modifying it inplace
4865
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
4866
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
4867
+
4868
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
4869
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4870
+ # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
4871
+ state_dict = self._maybe_expand_t2v_lora_for_i2v(
4872
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
4873
+ state_dict=state_dict,
4874
+ )
4875
+ is_correct_format = all("lora" in key for key in state_dict.keys())
4876
+ if not is_correct_format:
4877
+ raise ValueError("Invalid LoRA checkpoint.")
4878
+
4879
+ self.load_lora_into_transformer(
4880
+ state_dict,
4881
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
4882
+ adapter_name=adapter_name,
4883
+ _pipeline=self,
4884
+ low_cpu_mem_usage=low_cpu_mem_usage,
4885
+ )
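
A loading sketch for the Wan mixin; the model and LoRA ids are placeholders. A text-to-video LoRA also works here because the expansion helper above zero-fills the missing image-attention entries before the adapter is injected.

```py
# Usage sketch; the model and LoRA ids below are placeholders.
import torch
from diffusers import WanImageToVideoPipeline

pipe = WanImageToVideoPipeline.from_pretrained(
    "some/wan-i2v-model", torch_dtype=torch.bfloat16
).to("cuda")
# T2V LoRAs are expanded automatically by _maybe_expand_t2v_lora_for_i2v before loading.
pipe.load_lora_weights("some/wan-lora", weight_name="lora.safetensors", adapter_name="motion")
```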
4886
+
4887
+ @classmethod
4888
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
4889
+ def load_lora_into_transformer(
4890
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
4891
+ ):
4892
+ """
4893
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
4894
+
4895
+ Parameters:
4896
+ state_dict (`dict`):
4897
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
4898
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
4899
+ encoder lora layers.
4900
+ transformer (`WanTransformer3DModel`):
4901
+ The Transformer model to load the LoRA layers into.
4902
+ adapter_name (`str`, *optional*):
4903
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4904
+ `default_{i}` where i is the total number of adapters being loaded.
4905
+ low_cpu_mem_usage (`bool`, *optional*):
4906
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4907
+ weights.
4908
+ hotswap : (`bool`, *optional*)
4909
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
4910
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
4911
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
4912
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
4913
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
4914
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
4915
+
4916
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
4917
+ to call an additional method before loading the adapter:
4918
+
4919
+ ```py
4920
+ pipeline = ... # load diffusers pipeline
4921
+ max_rank = ... # the highest rank among all LoRAs that you want to load
4922
+ # call *before* compiling and loading the LoRA adapter
4923
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
4924
+ pipeline.load_lora_weights(file_name)
4925
+ # optionally compile the model now
4926
+ ```
4927
+
4928
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
4929
+ limitations to this technique, which are documented here:
4930
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
4931
+ """
4932
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4933
+ raise ValueError(
4934
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4935
+ )
4936
+
4937
+ # Load the layers corresponding to transformer.
4938
+ logger.info(f"Loading {cls.transformer_name}.")
4939
+ transformer.load_lora_adapter(
4940
+ state_dict,
4941
+ network_alphas=None,
4942
+ adapter_name=adapter_name,
4943
+ _pipeline=_pipeline,
4944
+ low_cpu_mem_usage=low_cpu_mem_usage,
4945
+ hotswap=hotswap,
4946
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
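For orientation, a sketch of how the save path above is typically driven from a training loop, assuming the pipeline class mixes in this loader. `get_peft_model_state_dict` is the standard PEFT helper for extracting adapter weights; the directory name is a placeholder, and depending on how the adapter was trained an additional key-format conversion may be needed before saving.

```py
from peft import get_peft_model_state_dict
from diffusers import WanPipeline

# `transformer` is assumed to be a WanTransformer3DModel that already carries trained LoRA layers.
transformer = ...

lora_layers = get_peft_model_state_dict(transformer)

# The classmethod packs the layers under the `transformer.` prefix and writes a single
# weights file (safetensors by default) into the target directory.
WanPipeline.save_lora_weights(
    save_directory="./wan-lora",
    transformer_lora_layers=lora_layers,
    safe_serialization=True,
)
```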
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and skip fusing weights that contain NaN.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+ class CogView4LoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`CogView4Transformer2DModel`]. Specific to [`CogView4Pipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return the state dict for the LoRA weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
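A quick way to inspect what the returned state dict looks like before committing to a load; the repo id and weight name are placeholders, and it is assumed that `CogView4Pipeline` mixes in this loader class.

```py
from diffusers import CogView4Pipeline

# Placeholder checkpoint - any diffusers-format LoRA for CogView4 would do.
state_dict = CogView4Pipeline.lora_state_dict(
    "some-user/cogview4-lora", weight_name="pytorch_lora_weights.safetensors"
)

# All keys should contain "lora"; `dora_scale` entries would have been filtered out with a warning.
print(len(state_dict), sorted(state_dict)[:3])
```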
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All kwargs
+ are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
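And a minimal end-to-end load on the pipeline side. The model and LoRA repo ids, adapter name, and prompt are placeholders; `low_cpu_mem_usage` and `adapter_name` are the kwargs documented in the docstring above.

```py
import torch
from diffusers import CogView4Pipeline

# Placeholder ids - substitute real checkpoints.
pipe = CogView4Pipeline.from_pretrained("<cogview4-repo-id>", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights(
    "some-user/cogview4-lora",
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="my_style",      # otherwise "default_0", "default_1", ... is assigned
    low_cpu_mem_usage=True,       # requires peft >= 0.13.0, per the version check above
)
image = pipe("a watercolor fox in a snowy forest").images[0]
```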
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish them
+ from text encoder LoRA layers.
+ transformer (`CogView4Transformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and skip fusing weights that contain NaN.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
  r"""
  Reverses the effect of
  [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -4224,11 +5389,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
  Args:
  components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
  unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- unfuse_text_encoder (`bool`, defaults to `True`):
- Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
- LoRA parameters then it won't have any effect.
  """
- super().unfuse_lora(components=components)
+ super().unfuse_lora(components=components, **kwargs)


  class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):