diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +198 -28
- diffusers/loaders/lora_conversion_utils.py +679 -44
- diffusers/loaders/lora_pipeline.py +1963 -801
- diffusers/loaders/peft.py +169 -84
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +653 -75
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +22 -32
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +10 -2
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +14 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.1.dist-info/RECORD +0 -550
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
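Functionally, the most visible API change in the `diffusers/loaders/lora_pipeline.py` hunks below is a new `hotswap` flag on `load_lora_weights`, `load_lora_into_unet`, and `load_lora_into_text_encoder`: instead of registering an additional adapter, it replaces the weights of an already loaded adapter in place, which among other things avoids recompilation when the model is wrapped in `torch.compile`. The following sketch is assembled from the docstrings added in this release; the base model ID and LoRA file names are placeholders, and both adapters are assumed to target only the UNet, since hotswapping text encoder adapters is not yet supported.

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder model and LoRA checkpoints; substitute your own.
base_model = "stable-diffusion-v1-5/stable-diffusion-v1-5"
lora_a = "style_a.safetensors"
lora_b = "style_b.safetensors"

pipe = StableDiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16).to("cuda")

# If the adapters differ in rank or LoRA alpha, prepare hotswapping *before*
# loading the first adapter and before compiling; pass the highest rank you plan to load.
pipe.enable_lora_hotswap(target_rank=64)

pipe.load_lora_weights(lora_a, adapter_name="style")
pipe.unet = torch.compile(pipe.unet)  # optional; hotswapping avoids recompiling later
image_a = pipe("an astronaut riding a horse").images[0]

# Swap the "style" adapter's weights in place instead of loading a second adapter.
pipe.load_lora_weights(lora_b, adapter_name="style", hotswap=True)
image_b = pipe("an astronaut riding a horse").images[0]
```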
diffusers/loaders/lora_pipeline.py

@@ -20,24 +20,31 @@ from huggingface_hub.utils import validate_hf_hub_args

 from ..utils import (
     USE_PEFT_BACKEND,
-    convert_state_dict_to_diffusers,
-    convert_state_dict_to_peft,
     deprecate,
-
-
+    get_submodule_by_name,
+    is_bitsandbytes_available,
+    is_gguf_available,
     is_peft_available,
     is_peft_version,
     is_torch_version,
     is_transformers_available,
     is_transformers_version,
     logging,
-    scale_lora_layers,
 )
-from .lora_base import
+from .lora_base import (  # noqa
+    LORA_WEIGHT_NAME,
+    LORA_WEIGHT_NAME_SAFE,
+    LoraBaseMixin,
+    _fetch_state_dict,
+    _load_lora_into_text_encoder,
+)
 from .lora_conversion_utils import (
     _convert_bfl_flux_control_lora_to_diffusers,
+    _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
+    _convert_non_diffusers_lumina2_lora_to_diffusers,
+    _convert_non_diffusers_wan_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
     _maybe_map_sgm_blocks_to_diffusers,
 )
@@ -54,9 +61,6 @@ if is_torch_version(">=", "1.9.0"):
     _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True


-if is_transformers_available():
-    from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
-
 logger = logging.get_logger(__name__)

 TEXT_ENCODER_NAME = "text_encoder"
@@ -66,6 +70,49 @@ TRANSFORMER_NAME = "transformer"
 _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}


+def _maybe_dequantize_weight_for_expanded_lora(model, module):
+    if is_bitsandbytes_available():
+        from ..quantizers.bitsandbytes import dequantize_bnb_weight
+
+    if is_gguf_available():
+        from ..quantizers.gguf.utils import dequantize_gguf_tensor
+
+    is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+    is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
+
+    if is_bnb_4bit_quantized and not is_bitsandbytes_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
+        )
+    if is_gguf_quantized and not is_gguf_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
+        )
+
+    weight_on_cpu = False
+    if not module.weight.is_cuda:
+        weight_on_cpu = True
+
+    if is_bnb_4bit_quantized:
+        module_weight = dequantize_bnb_weight(
+            module.weight.cuda() if weight_on_cpu else module.weight,
+            state=module.weight.quant_state,
+            dtype=model.dtype,
+        ).data
+    elif is_gguf_quantized:
+        module_weight = dequantize_gguf_tensor(
+            module.weight.cuda() if weight_on_cpu else module.weight,
+        )
+        module_weight = module_weight.to(model.dtype)
+    else:
+        module_weight = module.weight.data
+
+    if weight_on_cpu:
+        module_weight = module_weight.cpu()
+
+    return module_weight
+
+
 class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
     r"""
     Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
@@ -77,10 +124,13 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
     text_encoder_name = TEXT_ENCODER_NAME

     def load_lora_weights(
-        self,
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name=None,
+        hotswap: bool = False,
+        **kwargs,
     ):
-        """
-        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+        """Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
         `self.text_encoder`.

         All kwargs are forwarded to `self.lora_state_dict`.
@@ -103,6 +153,29 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+                to call an additional method before loading the adapter:
+
+                ```py
+                pipeline = ...  # load diffusers pipeline
+                max_rank = ...  # the highest rank among all LoRAs that you want to load
+                # call *before* compiling and loading the LoRA adapter
+                pipeline.enable_lora_hotswap(target_rank=max_rank)
+                pipeline.load_lora_weights(file_name)
+                # optionally compile the model now
+                ```
+
+                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+                limitations to this technique, which are documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
             kwargs (`dict`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
@@ -133,6 +206,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             adapter_name=adapter_name,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )
         self.load_lora_into_text_encoder(
             state_dict,
@@ -144,6 +218,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             adapter_name=adapter_name,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )

     @classmethod
@@ -263,7 +338,14 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):

     @classmethod
     def load_lora_into_unet(
-        cls,
+        cls,
+        state_dict,
+        network_alphas,
+        unet,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -285,6 +367,29 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+                to call an additional method before loading the adapter:
+
+                ```py
+                pipeline = ...  # load diffusers pipeline
+                max_rank = ...  # the highest rank among all LoRAs that you want to load
+                # call *before* compiling and loading the LoRA adapter
+                pipeline.enable_lora_hotswap(target_rank=max_rank)
+                pipeline.load_lora_weights(file_name)
+                # optionally compile the model now
+                ```
+
+                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+                limitations to this technique, which are documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -297,19 +402,16 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
-
-
-
-
-
-
-
-
-
-
-            _pipeline=_pipeline,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-        )
+        logger.info(f"Loading {cls.unet_name}.")
+        unet.load_lora_adapter(
+            state_dict,
+            prefix=cls.unet_name,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )

     @classmethod
     def load_lora_into_text_encoder(
@@ -322,6 +424,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         adapter_name=None,
         _pipeline=None,
         low_cpu_mem_usage=False,
+        hotswap: bool = False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -347,120 +450,42 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+                to call an additional method before loading the adapter:
+
+                ```py
+                pipeline = ...  # load diffusers pipeline
+                max_rank = ...  # the highest rank among all LoRAs that you want to load
+                # call *before* compiling and loading the LoRA adapter
+                pipeline.enable_lora_hotswap(target_rank=max_rank)
+                pipeline.load_lora_weights(file_name)
+                # optionally compile the model now
+                ```
+
+                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+                limitations to this technique, which are documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
         """
-
-
-
-
-
-
-
-
-
-
-
-
-                raise ValueError(
-                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
-                )
-            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
-        from peft import LoraConfig
-
-        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
-        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
-        # their prefixes.
-        keys = list(state_dict.keys())
-        prefix = cls.text_encoder_name if prefix is None else prefix
-
-        # Safe prefix to check with.
-        if any(cls.text_encoder_name in key for key in keys):
-            # Load the layers corresponding to text encoder and make necessary adjustments.
-            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
-            text_encoder_lora_state_dict = {
-                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
-            }
-
-            if len(text_encoder_lora_state_dict) > 0:
-                logger.info(f"Loading {prefix}.")
-                rank = {}
-                text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
-
-                # convert state dict
-                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-
-                for name, _ in text_encoder_attn_modules(text_encoder):
-                    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
-                for name, _ in text_encoder_mlp_modules(text_encoder):
-                    for module in ("fc1", "fc2"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
-                if network_alphas is not None:
-                    alpha_keys = [
-                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
-                    ]
-                    network_alphas = {
-                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
-                    }
-
-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
-
-                if "use_dora" in lora_config_kwargs:
-                    if lora_config_kwargs["use_dora"]:
-                        if is_peft_version("<", "0.9.0"):
-                            raise ValueError(
-                                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<", "0.9.0"):
-                            lora_config_kwargs.pop("use_dora")
-
-                if "lora_bias" in lora_config_kwargs:
-                    if lora_config_kwargs["lora_bias"]:
-                        if is_peft_version("<=", "0.13.2"):
-                            raise ValueError(
-                                "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<=", "0.13.2"):
-                            lora_config_kwargs.pop("lora_bias")
-
-                lora_config = LoraConfig(**lora_config_kwargs)
-
-                # adapter_name
-                if adapter_name is None:
-                    adapter_name = get_adapter_name(text_encoder)
-
-                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-                # inject LoRA layers and load the state dict
-                # in transformers we automatically check whether the adapter name is already in use or not
-                text_encoder.load_adapter(
-                    adapter_name=adapter_name,
-                    adapter_state_dict=text_encoder_lora_state_dict,
-                    peft_config=lora_config,
-                    **peft_kwargs,
-                )
-
-                # scale LoRA layers with `lora_scale`
-                scale_lora_layers(text_encoder, weight=lora_scale)
-
-                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
-
-                # Offload back.
-                if is_model_cpu_offload:
-                    _pipeline.enable_model_cpu_offload()
-                elif is_sequential_cpu_offload:
-                    _pipeline.enable_sequential_cpu_offload()
-                # Unsafe code />
+        _load_lora_into_text_encoder(
+            state_dict=state_dict,
+            network_alphas=network_alphas,
+            lora_scale=lora_scale,
+            text_encoder=text_encoder,
+            prefix=prefix,
+            text_encoder_name=cls.text_encoder_name,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )

     @classmethod
     def save_lora_weights(
@@ -556,7 +581,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         ```
         """
         super().fuse_lora(
-            components=components,
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
         )

     def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
@@ -577,7 +606,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                 Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                 LoRA parameters then it won't have any effect.
         """
-        super().unfuse_lora(components=components)
+        super().unfuse_lora(components=components, **kwargs)


 class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
@@ -660,31 +689,26 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        self.
-
-
-
-
-            lora_scale=self.lora_scale,
-            adapter_name=adapter_name,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-        )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=network_alphas,
+            text_encoder=self.text_encoder,
+            prefix=self.text_encoder_name,
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=network_alphas,
+            text_encoder=self.text_encoder_2,
+            prefix=f"{self.text_encoder_name}_2",
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

     @classmethod
     @validate_hf_hub_args
@@ -805,7 +829,14 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
     def load_lora_into_unet(
-        cls,
+        cls,
+        state_dict,
+        network_alphas,
+        unet,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -827,6 +858,29 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+                to call an additional method before loading the adapter:
+
+                ```py
+                pipeline = ...  # load diffusers pipeline
+                max_rank = ...  # the highest rank among all LoRAs that you want to load
+                # call *before* compiling and loading the LoRA adapter
+                pipeline.enable_lora_hotswap(target_rank=max_rank)
+                pipeline.load_lora_weights(file_name)
+                # optionally compile the model now
+                ```
+
+                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+                limitations to this technique, which are documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -839,19 +893,16 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
-
-
-
-
-
-
-
-
-
-
-            _pipeline=_pipeline,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-        )
+        logger.info(f"Loading {cls.unet_name}.")
+        unet.load_lora_adapter(
+            state_dict,
+            prefix=cls.unet_name,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )

     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -865,6 +916,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         adapter_name=None,
         _pipeline=None,
         low_cpu_mem_usage=False,
+        hotswap: bool = False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -890,120 +942,42 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+                to call an additional method before loading the adapter:
+
+                ```py
+                pipeline = ...  # load diffusers pipeline
+                max_rank = ...  # the highest rank among all LoRAs that you want to load
+                # call *before* compiling and loading the LoRA adapter
+                pipeline.enable_lora_hotswap(target_rank=max_rank)
+                pipeline.load_lora_weights(file_name)
+                # optionally compile the model now
+                ```
+
+                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+                limitations to this technique, which are documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
         """
-
-
-
-
-
-
-
-
-
-
-
-
-                raise ValueError(
-                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
-                )
-            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
-        from peft import LoraConfig
-
-        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
-        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
-        # their prefixes.
-        keys = list(state_dict.keys())
-        prefix = cls.text_encoder_name if prefix is None else prefix
-
-        # Safe prefix to check with.
-        if any(cls.text_encoder_name in key for key in keys):
-            # Load the layers corresponding to text encoder and make necessary adjustments.
-            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
-            text_encoder_lora_state_dict = {
-                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
-            }
-
-            if len(text_encoder_lora_state_dict) > 0:
-                logger.info(f"Loading {prefix}.")
-                rank = {}
-                text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
-
-                # convert state dict
-                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-
-                for name, _ in text_encoder_attn_modules(text_encoder):
-                    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
-                for name, _ in text_encoder_mlp_modules(text_encoder):
-                    for module in ("fc1", "fc2"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
-                if network_alphas is not None:
-                    alpha_keys = [
-                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
-                    ]
-                    network_alphas = {
-                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
-                    }
-
-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
-
-                if "use_dora" in lora_config_kwargs:
-                    if lora_config_kwargs["use_dora"]:
-                        if is_peft_version("<", "0.9.0"):
-                            raise ValueError(
-                                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<", "0.9.0"):
-                            lora_config_kwargs.pop("use_dora")
-
-                if "lora_bias" in lora_config_kwargs:
-                    if lora_config_kwargs["lora_bias"]:
-                        if is_peft_version("<=", "0.13.2"):
-                            raise ValueError(
-                                "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<=", "0.13.2"):
-                            lora_config_kwargs.pop("lora_bias")
-
-                lora_config = LoraConfig(**lora_config_kwargs)
-
-                # adapter_name
-                if adapter_name is None:
-                    adapter_name = get_adapter_name(text_encoder)
-
-                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-                # inject LoRA layers and load the state dict
-                # in transformers we automatically check whether the adapter name is already in use or not
-                text_encoder.load_adapter(
-                    adapter_name=adapter_name,
-                    adapter_state_dict=text_encoder_lora_state_dict,
-                    peft_config=lora_config,
-                    **peft_kwargs,
-                )
-
-                # scale LoRA layers with `lora_scale`
-                scale_lora_layers(text_encoder, weight=lora_scale)
-
-                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
-
-                # Offload back.
-                if is_model_cpu_offload:
-                    _pipeline.enable_model_cpu_offload()
-                elif is_sequential_cpu_offload:
-                    _pipeline.enable_sequential_cpu_offload()
-                # Unsafe code />
+        _load_lora_into_text_encoder(
+            state_dict=state_dict,
+            network_alphas=network_alphas,
+            lora_scale=lora_scale,
+            text_encoder=text_encoder,
+            prefix=prefix,
+            text_encoder_name=cls.text_encoder_name,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )

     @classmethod
     def save_lora_weights(
@@ -1046,11 +1020,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):

         if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
             raise ValueError(
-                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers
+                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
             )

         if unet_lora_layers:
-            state_dict.update(cls.pack_weights(unet_lora_layers,
+            state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))

         if text_encoder_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
@@ -1107,7 +1081,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
|
1107
1081
|
```
|
1108
1082
|
"""
|
1109
1083
|
super().fuse_lora(
|
1110
|
-
components=components,
|
1084
|
+
components=components,
|
1085
|
+
lora_scale=lora_scale,
|
1086
|
+
safe_fusing=safe_fusing,
|
1087
|
+
adapter_names=adapter_names,
|
1088
|
+
**kwargs,
|
1111
1089
|
)
|
1112
1090
|
|
1113
1091
|
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
|
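The `fuse_lora` override above now forwards `lora_scale`, `safe_fusing`, and `adapter_names` to the base implementation instead of silently dropping them. A hedged usage sketch of the public call; the model and LoRA repo ids are illustrative and not part of this diff:

```py
import torch
from diffusers import StableDiffusionXLPipeline

# Illustrative repo ids only; any SDXL pipeline with a loaded LoRA works the same way.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("ostris/ikea-instructions-lora-sdxl", adapter_name="ikea")

pipe.fuse_lora(
    components=["unet", "text_encoder", "text_encoder_2"],
    lora_scale=0.8,
    safe_fusing=True,          # skip fusing if the fused weights would contain NaNs
    adapter_names=["ikea"],
)
image = pipe("a sofa, ikea instructions style").images[0]
pipe.unfuse_lora()             # restore the original, unfused base weights
```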
@@ -1128,7 +1106,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
|
1128
1106
|
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
1129
1107
|
LoRA parameters then it won't have any effect.
|
1130
1108
|
"""
|
1131
|
-
super().unfuse_lora(components=components)
|
1109
|
+
super().unfuse_lora(components=components, **kwargs)
|
1132
1110
|
|
1133
1111
|
|
1134
1112
|
class SD3LoraLoaderMixin(LoraBaseMixin):
|
@@ -1242,7 +1220,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1242
1220
|
return state_dict
|
1243
1221
|
|
1244
1222
|
def load_lora_weights(
|
1245
|
-
self,
|
1223
|
+
self,
|
1224
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
1225
|
+
adapter_name=None,
|
1226
|
+
hotswap: bool = False,
|
1227
|
+
**kwargs,
|
1246
1228
|
):
|
1247
1229
|
"""
|
1248
1230
|
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
|
@@ -1265,6 +1247,29 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1265
1247
|
low_cpu_mem_usage (`bool`, *optional*):
|
1266
1248
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
1267
1249
|
weights.
|
1250
|
+
hotswap : (`bool`, *optional*)
|
1251
|
+
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
1252
|
+
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
1253
|
+
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
1254
|
+
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
1255
|
+
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
1256
|
+
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
1257
|
+
|
1258
|
+
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
1259
|
+
to call an additional method before loading the adapter:
|
1260
|
+
|
1261
|
+
```py
|
1262
|
+
pipeline = ... # load diffusers pipeline
|
1263
|
+
max_rank = ... # the highest rank among all LoRAs that you want to load
|
1264
|
+
# call *before* compiling and loading the LoRA adapter
|
1265
|
+
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
1266
|
+
pipeline.load_lora_weights(file_name)
|
1267
|
+
# optionally compile the model now
|
1268
|
+
```
|
1269
|
+
|
1270
|
+
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
1271
|
+
limitations to this technique, which are documented here:
|
1272
|
+
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
1268
1273
|
kwargs (`dict`, *optional*):
|
1269
1274
|
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
1270
1275
|
"""
|
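The hotswap docstring above prescribes a specific call order; the condensed sketch below strings those steps together. The LoRA file names and the SD3 repo id are placeholders:

```py
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# 1. Prepare for hotswapping *before* loading or compiling; target_rank must cover all LoRAs.
pipe.enable_lora_hotswap(target_rank=64)

# 2. Load the first adapter and (optionally) compile once.
pipe.load_lora_weights("lora_one.safetensors", adapter_name="default")
pipe.transformer = torch.compile(pipe.transformer)

# 3. Later, swap a second adapter into the same slot in place; no recompilation is triggered.
pipe.load_lora_weights("lora_two.safetensors", adapter_name="default", hotswap=True)
```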
@@ -1288,47 +1293,40 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1288
1293
|
if not is_correct_format:
|
1289
1294
|
raise ValueError("Invalid LoRA checkpoint.")
|
1290
1295
|
|
1291
|
-
|
1292
|
-
|
1293
|
-
self.
|
1294
|
-
|
1295
|
-
|
1296
|
-
|
1297
|
-
|
1298
|
-
|
1299
|
-
|
1300
|
-
|
1301
|
-
|
1302
|
-
|
1303
|
-
|
1304
|
-
|
1305
|
-
|
1306
|
-
|
1307
|
-
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1313
|
-
|
1314
|
-
|
1315
|
-
|
1316
|
-
|
1317
|
-
|
1318
|
-
|
1319
|
-
|
1320
|
-
|
1321
|
-
text_encoder=self.text_encoder_2,
|
1322
|
-
prefix="text_encoder_2",
|
1323
|
-
lora_scale=self.lora_scale,
|
1324
|
-
adapter_name=adapter_name,
|
1325
|
-
_pipeline=self,
|
1326
|
-
low_cpu_mem_usage=low_cpu_mem_usage,
|
1327
|
-
)
|
1296
|
+
self.load_lora_into_transformer(
|
1297
|
+
state_dict,
|
1298
|
+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
1299
|
+
adapter_name=adapter_name,
|
1300
|
+
_pipeline=self,
|
1301
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
1302
|
+
hotswap=hotswap,
|
1303
|
+
)
|
1304
|
+
self.load_lora_into_text_encoder(
|
1305
|
+
state_dict,
|
1306
|
+
network_alphas=None,
|
1307
|
+
text_encoder=self.text_encoder,
|
1308
|
+
prefix=self.text_encoder_name,
|
1309
|
+
lora_scale=self.lora_scale,
|
1310
|
+
adapter_name=adapter_name,
|
1311
|
+
_pipeline=self,
|
1312
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
1313
|
+
hotswap=hotswap,
|
1314
|
+
)
|
1315
|
+
self.load_lora_into_text_encoder(
|
1316
|
+
state_dict,
|
1317
|
+
network_alphas=None,
|
1318
|
+
text_encoder=self.text_encoder_2,
|
1319
|
+
prefix=f"{self.text_encoder_name}_2",
|
1320
|
+
lora_scale=self.lora_scale,
|
1321
|
+
adapter_name=adapter_name,
|
1322
|
+
_pipeline=self,
|
1323
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
1324
|
+
hotswap=hotswap,
|
1325
|
+
)
|
1328
1326
|
|
1329
1327
|
@classmethod
|
1330
1328
|
def load_lora_into_transformer(
|
1331
|
-
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
1329
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
1332
1330
|
):
|
1333
1331
|
"""
|
1334
1332
|
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
@@ -1346,6 +1344,29 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1346
1344
|
low_cpu_mem_usage (`bool`, *optional*):
|
1347
1345
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
1348
1346
|
weights.
|
1347
|
+
hotswap : (`bool`, *optional*)
|
1348
|
+
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
1349
|
+
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
1350
|
+
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
1351
|
+
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
1352
|
+
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
1353
|
+
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
1354
|
+
|
1355
|
+
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
1356
|
+
to call an additional method before loading the adapter:
|
1357
|
+
|
1358
|
+
```py
|
1359
|
+
pipeline = ... # load diffusers pipeline
|
1360
|
+
max_rank = ... # the highest rank among all LoRAs that you want to load
|
1361
|
+
# call *before* compiling and loading the LoRA adapter
|
1362
|
+
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
1363
|
+
pipeline.load_lora_weights(file_name)
|
1364
|
+
# optionally compile the model now
|
1365
|
+
```
|
1366
|
+
|
1367
|
+
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
1368
|
+
limitations to this technique, which are documented here:
|
1369
|
+
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
1349
1370
|
"""
|
1350
1371
|
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
1351
1372
|
raise ValueError(
|
@@ -1360,6 +1381,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1360
1381
|
adapter_name=adapter_name,
|
1361
1382
|
_pipeline=_pipeline,
|
1362
1383
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
1384
|
+
hotswap=hotswap,
|
1363
1385
|
)
|
1364
1386
|
|
1365
1387
|
@classmethod
|
@@ -1374,6 +1396,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1374
1396
|
adapter_name=None,
|
1375
1397
|
_pipeline=None,
|
1376
1398
|
low_cpu_mem_usage=False,
|
1399
|
+
hotswap: bool = False,
|
1377
1400
|
):
|
1378
1401
|
"""
|
1379
1402
|
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
@@ -1399,126 +1422,49 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1399
1422
|
low_cpu_mem_usage (`bool`, *optional*):
|
1400
1423
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
1401
1424
|
weights.
|
1425
|
+
hotswap : (`bool`, *optional*)
|
1426
|
+
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
1427
|
+
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
1428
|
+
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
1429
|
+
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
1430
|
+
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
1431
|
+
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
1432
|
+
|
1433
|
+
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
1434
|
+
to call an additional method before loading the adapter:
|
1435
|
+
|
1436
|
+
```py
|
1437
|
+
pipeline = ... # load diffusers pipeline
|
1438
|
+
max_rank = ... # the highest rank among all LoRAs that you want to load
|
1439
|
+
# call *before* compiling and loading the LoRA adapter
|
1440
|
+
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
1441
|
+
pipeline.load_lora_weights(file_name)
|
1442
|
+
# optionally compile the model now
|
1443
|
+
```
|
1444
|
+
|
1445
|
+
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
1446
|
+
limitations to this technique, which are documented here:
|
1447
|
+
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
1402
1448
|
"""
|
1403
|
-
|
1404
|
-
|
1405
|
-
|
1406
|
-
|
1407
|
-
|
1408
|
-
|
1409
|
-
|
1410
|
-
|
1411
|
-
|
1412
|
-
|
1413
|
-
|
1414
|
-
|
1415
|
-
raise ValueError(
|
1416
|
-
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
1417
|
-
)
|
1418
|
-
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
1419
|
-
|
1420
|
-
from peft import LoraConfig
|
1421
|
-
|
1422
|
-
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
1423
|
-
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
1424
|
-
# their prefixes.
|
1425
|
-
keys = list(state_dict.keys())
|
1426
|
-
prefix = cls.text_encoder_name if prefix is None else prefix
|
1427
|
-
|
1428
|
-
# Safe prefix to check with.
|
1429
|
-
if any(cls.text_encoder_name in key for key in keys):
|
1430
|
-
# Load the layers corresponding to text encoder and make necessary adjustments.
|
1431
|
-
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
|
1432
|
-
text_encoder_lora_state_dict = {
|
1433
|
-
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
1434
|
-
}
|
1435
|
-
|
1436
|
-
if len(text_encoder_lora_state_dict) > 0:
|
1437
|
-
logger.info(f"Loading {prefix}.")
|
1438
|
-
rank = {}
|
1439
|
-
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
1440
|
-
|
1441
|
-
# convert state dict
|
1442
|
-
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
1443
|
-
|
1444
|
-
for name, _ in text_encoder_attn_modules(text_encoder):
|
1445
|
-
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
1446
|
-
rank_key = f"{name}.{module}.lora_B.weight"
|
1447
|
-
if rank_key not in text_encoder_lora_state_dict:
|
1448
|
-
continue
|
1449
|
-
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
1450
|
-
|
1451
|
-
for name, _ in text_encoder_mlp_modules(text_encoder):
|
1452
|
-
for module in ("fc1", "fc2"):
|
1453
|
-
rank_key = f"{name}.{module}.lora_B.weight"
|
1454
|
-
if rank_key not in text_encoder_lora_state_dict:
|
1455
|
-
continue
|
1456
|
-
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
1457
|
-
|
1458
|
-
if network_alphas is not None:
|
1459
|
-
alpha_keys = [
|
1460
|
-
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
|
1461
|
-
]
|
1462
|
-
network_alphas = {
|
1463
|
-
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
1464
|
-
}
|
1465
|
-
|
1466
|
-
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
1467
|
-
|
1468
|
-
if "use_dora" in lora_config_kwargs:
|
1469
|
-
if lora_config_kwargs["use_dora"]:
|
1470
|
-
if is_peft_version("<", "0.9.0"):
|
1471
|
-
raise ValueError(
|
1472
|
-
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
1473
|
-
)
|
1474
|
-
else:
|
1475
|
-
if is_peft_version("<", "0.9.0"):
|
1476
|
-
lora_config_kwargs.pop("use_dora")
|
1477
|
-
|
1478
|
-
if "lora_bias" in lora_config_kwargs:
|
1479
|
-
if lora_config_kwargs["lora_bias"]:
|
1480
|
-
if is_peft_version("<=", "0.13.2"):
|
1481
|
-
raise ValueError(
|
1482
|
-
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
1483
|
-
)
|
1484
|
-
else:
|
1485
|
-
if is_peft_version("<=", "0.13.2"):
|
1486
|
-
lora_config_kwargs.pop("lora_bias")
|
1487
|
-
|
1488
|
-
lora_config = LoraConfig(**lora_config_kwargs)
|
1489
|
-
|
1490
|
-
# adapter_name
|
1491
|
-
if adapter_name is None:
|
1492
|
-
adapter_name = get_adapter_name(text_encoder)
|
1493
|
-
|
1494
|
-
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
1495
|
-
|
1496
|
-
# inject LoRA layers and load the state dict
|
1497
|
-
# in transformers we automatically check whether the adapter name is already in use or not
|
1498
|
-
text_encoder.load_adapter(
|
1499
|
-
adapter_name=adapter_name,
|
1500
|
-
adapter_state_dict=text_encoder_lora_state_dict,
|
1501
|
-
peft_config=lora_config,
|
1502
|
-
**peft_kwargs,
|
1503
|
-
)
|
1504
|
-
|
1505
|
-
# scale LoRA layers with `lora_scale`
|
1506
|
-
scale_lora_layers(text_encoder, weight=lora_scale)
|
1507
|
-
|
1508
|
-
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
1509
|
-
|
1510
|
-
# Offload back.
|
1511
|
-
if is_model_cpu_offload:
|
1512
|
-
_pipeline.enable_model_cpu_offload()
|
1513
|
-
elif is_sequential_cpu_offload:
|
1514
|
-
_pipeline.enable_sequential_cpu_offload()
|
|
1449
|
+
_load_lora_into_text_encoder(
|
1450
|
+
state_dict=state_dict,
|
1451
|
+
network_alphas=network_alphas,
|
1452
|
+
lora_scale=lora_scale,
|
1453
|
+
text_encoder=text_encoder,
|
1454
|
+
prefix=prefix,
|
1455
|
+
text_encoder_name=cls.text_encoder_name,
|
1456
|
+
adapter_name=adapter_name,
|
1457
|
+
_pipeline=_pipeline,
|
1458
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
1459
|
+
hotswap=hotswap,
|
1460
|
+
)
|
1516
1461
|
|
1517
1462
|
@classmethod
|
1463
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
|
1518
1464
|
def save_lora_weights(
|
1519
1465
|
cls,
|
1520
1466
|
save_directory: Union[str, os.PathLike],
|
1521
|
-
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
|
1467
|
+
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
1522
1468
|
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
1523
1469
|
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
1524
1470
|
is_main_process: bool = True,
|
@@ -1567,7 +1513,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1567
1513
|
if text_encoder_2_lora_layers:
|
1568
1514
|
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
1569
1515
|
|
1570
|
-
# Save the model
|
1571
1516
|
cls.write_lora_layers(
|
1572
1517
|
state_dict=state_dict,
|
1573
1518
|
save_directory=save_directory,
|
@@ -1577,6 +1522,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1577
1522
|
safe_serialization=safe_serialization,
|
1578
1523
|
)
|
1579
1524
|
|
1525
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
|
1580
1526
|
def fuse_lora(
|
1581
1527
|
self,
|
1582
1528
|
components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
|
@@ -1617,9 +1563,14 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1617
1563
|
```
|
1618
1564
|
"""
|
1619
1565
|
super().fuse_lora(
|
1620
|
-
components=components,
|
1566
|
+
components=components,
|
1567
|
+
lora_scale=lora_scale,
|
1568
|
+
safe_fusing=safe_fusing,
|
1569
|
+
adapter_names=adapter_names,
|
1570
|
+
**kwargs,
|
1621
1571
|
)
|
1622
1572
|
|
1573
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
|
1623
1574
|
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
|
1624
1575
|
r"""
|
1625
1576
|
Reverses the effect of
|
@@ -1633,12 +1584,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1633
1584
|
|
1634
1585
|
Args:
|
1635
1586
|
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
1636
|
-
|
1587
|
+
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
1637
1588
|
unfuse_text_encoder (`bool`, defaults to `True`):
|
1638
1589
|
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
1639
1590
|
LoRA parameters then it won't have any effect.
|
1640
1591
|
"""
|
1641
|
-
super().unfuse_lora(components=components)
|
1592
|
+
super().unfuse_lora(components=components, **kwargs)
|
1642
1593
|
|
1643
1594
|
|
1644
1595
|
class FluxLoraLoaderMixin(LoraBaseMixin):
|
@@ -1789,7 +1740,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1789
1740
|
return state_dict
|
1790
1741
|
|
1791
1742
|
def load_lora_weights(
|
1792
|
-
self,
|
1743
|
+
self,
|
1744
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
1745
|
+
adapter_name=None,
|
1746
|
+
hotswap: bool = False,
|
1747
|
+
**kwargs,
|
1793
1748
|
):
|
1794
1749
|
"""
|
1795
1750
|
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
@@ -1814,6 +1769,26 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1814
1769
|
low_cpu_mem_usage (`bool`, *optional*):
|
1815
1770
|
`Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
1816
1771
|
weights.
|
1772
|
+
hotswap : (`bool`, *optional*)
|
1773
|
+
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
1774
|
+
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
1775
|
+
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
1776
|
+
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
1777
|
+
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
1778
|
+
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new
|
1779
|
+
adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an
|
1780
|
+
additional method before loading the adapter:
|
1781
|
+
```py
|
1782
|
+
pipeline = ... # load diffusers pipeline
|
1783
|
+
max_rank = ... # the highest rank among all LoRAs that you want to load
|
1784
|
+
# call *before* compiling and loading the LoRA adapter
|
1785
|
+
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
1786
|
+
pipeline.load_lora_weights(file_name)
|
1787
|
+
# optionally compile the model now
|
1788
|
+
```
|
1789
|
+
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
1790
|
+
limitations to this technique, which are documented here:
|
1791
|
+
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
1817
1792
|
"""
|
1818
1793
|
if not USE_PEFT_BACKEND:
|
1819
1794
|
raise ValueError("PEFT backend is required for this method.")
|
@@ -1844,18 +1819,23 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             raise ValueError("Invalid LoRA checkpoint.")
 
         transformer_lora_state_dict = {
-            k: state_dict.
+            k: state_dict.get(k)
+            for k in list(state_dict.keys())
+            if k.startswith(f"{self.transformer_name}.") and "lora" in k
         }
         transformer_norm_state_dict = {
             k: state_dict.pop(k)
             for k in list(state_dict.keys())
-            if
+            if k.startswith(f"{self.transformer_name}.")
+            and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
         }
 
         transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
-        has_param_with_expanded_shape =
-
-
+        has_param_with_expanded_shape = False
+        if len(transformer_lora_state_dict) > 0:
+            has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
+                transformer, transformer_lora_state_dict, transformer_norm_state_dict
+            )
 
         if has_param_with_expanded_shape:
             logger.info(
@@ -1863,19 +1843,22 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
                 "To get a comprehensive list of parameter names that were modified, enable debug logging."
             )
-            transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
-                transformer=transformer, lora_state_dict=transformer_lora_state_dict
-            )
-
         if len(transformer_lora_state_dict) > 0:
-            self.
-            transformer_lora_state_dict
-                network_alphas=network_alphas,
-                transformer=transformer,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
+            transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
+                transformer=transformer, lora_state_dict=transformer_lora_state_dict
             )
+            for k in transformer_lora_state_dict:
+                state_dict.update({k: transformer_lora_state_dict[k]})
+
+            self.load_lora_into_transformer(
+                state_dict,
+                network_alphas=network_alphas,
+                transformer=transformer,
+                adapter_name=adapter_name,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
 
         if len(transformer_norm_state_dict) > 0:
             transformer._transformer_norm_layers = self._load_norm_into_transformer(
@@ -1884,22 +1867,28 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1884
1867
|
discard_original_layers=False,
|
1885
1868
|
)
|
1886
1869
|
|
1887
|
-
|
1888
|
-
|
1889
|
-
|
1890
|
-
|
1891
|
-
|
1892
|
-
|
1893
|
-
|
1894
|
-
|
1895
|
-
|
1896
|
-
|
1897
|
-
|
1898
|
-
)
|
1870
|
+
self.load_lora_into_text_encoder(
|
1871
|
+
state_dict,
|
1872
|
+
network_alphas=network_alphas,
|
1873
|
+
text_encoder=self.text_encoder,
|
1874
|
+
prefix=self.text_encoder_name,
|
1875
|
+
lora_scale=self.lora_scale,
|
1876
|
+
adapter_name=adapter_name,
|
1877
|
+
_pipeline=self,
|
1878
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
1879
|
+
hotswap=hotswap,
|
1880
|
+
)
|
1899
1881
|
|
1900
1882
|
@classmethod
|
1901
1883
|
def load_lora_into_transformer(
|
1902
|
-
cls,
|
1884
|
+
cls,
|
1885
|
+
state_dict,
|
1886
|
+
network_alphas,
|
1887
|
+
transformer,
|
1888
|
+
adapter_name=None,
|
1889
|
+
_pipeline=None,
|
1890
|
+
low_cpu_mem_usage=False,
|
1891
|
+
hotswap: bool = False,
|
1903
1892
|
):
|
1904
1893
|
"""
|
1905
1894
|
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
@@ -1921,6 +1910,29 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1921
1910
|
low_cpu_mem_usage (`bool`, *optional*):
|
1922
1911
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
1923
1912
|
weights.
|
1913
|
+
hotswap : (`bool`, *optional*)
|
1914
|
+
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
1915
|
+
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
1916
|
+
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
1917
|
+
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
1918
|
+
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
1919
|
+
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
1920
|
+
|
1921
|
+
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
1922
|
+
to call an additional method before loading the adapter:
|
1923
|
+
|
1924
|
+
```py
|
1925
|
+
pipeline = ... # load diffusers pipeline
|
1926
|
+
max_rank = ... # the highest rank among all LoRAs that you want to load
|
1927
|
+
# call *before* compiling and loading the LoRA adapter
|
1928
|
+
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
1929
|
+
pipeline.load_lora_weights(file_name)
|
1930
|
+
# optionally compile the model now
|
1931
|
+
```
|
1932
|
+
|
1933
|
+
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
1934
|
+
limitations to this technique, which are documented here:
|
1935
|
+
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
1924
1936
|
"""
|
1925
1937
|
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
1926
1938
|
raise ValueError(
|
@@ -1928,17 +1940,15 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1928
1940
|
)
|
1929
1941
|
|
1930
1942
|
# Load the layers corresponding to transformer.
|
1931
|
-
|
1932
|
-
|
1933
|
-
|
1934
|
-
|
1935
|
-
|
1936
|
-
|
1937
|
-
|
1938
|
-
|
1939
|
-
|
1940
|
-
low_cpu_mem_usage=low_cpu_mem_usage,
|
1941
|
-
)
|
1943
|
+
logger.info(f"Loading {cls.transformer_name}.")
|
1944
|
+
transformer.load_lora_adapter(
|
1945
|
+
state_dict,
|
1946
|
+
network_alphas=network_alphas,
|
1947
|
+
adapter_name=adapter_name,
|
1948
|
+
_pipeline=_pipeline,
|
1949
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
1950
|
+
hotswap=hotswap,
|
1951
|
+
)
|
1942
1952
|
|
1943
1953
|
@classmethod
|
1944
1954
|
def _load_norm_into_transformer(
|
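The Flux `load_lora_into_transformer` above now delegates to the model-level `load_lora_adapter` call. A hedged sketch of invoking that entry point directly on a transformer; the repo id, file name, and the exact keyword set accepted publicly are assumptions, not guaranteed by this diff:

```py
from diffusers import FluxTransformer2DModel
from safetensors.torch import load_file

# Illustrative repo id and file name; the dict form of the first argument is an assumption.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)
lora_state_dict = load_file("flux_lora.safetensors")
transformer.load_lora_adapter(lora_state_dict, adapter_name="default", low_cpu_mem_usage=True)
```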
@@ -2006,6 +2016,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2006
2016
|
adapter_name=None,
|
2007
2017
|
_pipeline=None,
|
2008
2018
|
low_cpu_mem_usage=False,
|
2019
|
+
hotswap: bool = False,
|
2009
2020
|
):
|
2010
2021
|
"""
|
2011
2022
|
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
@@ -2031,120 +2042,42 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2031
2042
|
low_cpu_mem_usage (`bool`, *optional*):
|
2032
2043
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
2033
2044
|
weights.
|
2045
|
+
hotswap : (`bool`, *optional*)
|
2046
|
+
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
2047
|
+
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
2048
|
+
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
2049
|
+
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
2050
|
+
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
2051
|
+
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
2052
|
+
|
2053
|
+
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
2054
|
+
to call an additional method before loading the adapter:
|
2055
|
+
|
2056
|
+
```py
|
2057
|
+
pipeline = ... # load diffusers pipeline
|
2058
|
+
max_rank = ... # the highest rank among all LoRAs that you want to load
|
2059
|
+
# call *before* compiling and loading the LoRA adapter
|
2060
|
+
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
2061
|
+
pipeline.load_lora_weights(file_name)
|
2062
|
+
# optionally compile the model now
|
2063
|
+
```
|
2064
|
+
|
2065
|
+
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
2066
|
+
limitations to this technique, which are documented here:
|
2067
|
+
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
2034
2068
|
"""
|
2035
|
-
|
2036
|
-
|
2037
|
-
|
2038
|
-
|
2039
|
-
|
2040
|
-
|
2041
|
-
|
2042
|
-
|
2043
|
-
|
2044
|
-
|
2045
|
-
|
2046
|
-
|
2047
|
-
raise ValueError(
|
2048
|
-
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
2049
|
-
)
|
2050
|
-
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
2051
|
-
|
2052
|
-
from peft import LoraConfig
|
2053
|
-
|
2054
|
-
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
2055
|
-
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
2056
|
-
# their prefixes.
|
2057
|
-
keys = list(state_dict.keys())
|
2058
|
-
prefix = cls.text_encoder_name if prefix is None else prefix
|
2059
|
-
|
2060
|
-
# Safe prefix to check with.
|
2061
|
-
if any(cls.text_encoder_name in key for key in keys):
|
2062
|
-
# Load the layers corresponding to text encoder and make necessary adjustments.
|
2063
|
-
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
|
2064
|
-
text_encoder_lora_state_dict = {
|
2065
|
-
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
2066
|
-
}
|
2067
|
-
|
2068
|
-
if len(text_encoder_lora_state_dict) > 0:
|
2069
|
-
logger.info(f"Loading {prefix}.")
|
2070
|
-
rank = {}
|
2071
|
-
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
2072
|
-
|
2073
|
-
# convert state dict
|
2074
|
-
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
2075
|
-
|
2076
|
-
for name, _ in text_encoder_attn_modules(text_encoder):
|
2077
|
-
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
2078
|
-
rank_key = f"{name}.{module}.lora_B.weight"
|
2079
|
-
if rank_key not in text_encoder_lora_state_dict:
|
2080
|
-
continue
|
2081
|
-
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
2082
|
-
|
2083
|
-
for name, _ in text_encoder_mlp_modules(text_encoder):
|
2084
|
-
for module in ("fc1", "fc2"):
|
2085
|
-
rank_key = f"{name}.{module}.lora_B.weight"
|
2086
|
-
if rank_key not in text_encoder_lora_state_dict:
|
2087
|
-
continue
|
2088
|
-
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
2089
|
-
|
2090
|
-
if network_alphas is not None:
|
2091
|
-
alpha_keys = [
|
2092
|
-
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
|
2093
|
-
]
|
2094
|
-
network_alphas = {
|
2095
|
-
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
2096
|
-
}
|
2097
|
-
|
2098
|
-
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
2099
|
-
|
2100
|
-
if "use_dora" in lora_config_kwargs:
|
2101
|
-
if lora_config_kwargs["use_dora"]:
|
2102
|
-
if is_peft_version("<", "0.9.0"):
|
2103
|
-
raise ValueError(
|
2104
|
-
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
2105
|
-
)
|
2106
|
-
else:
|
2107
|
-
if is_peft_version("<", "0.9.0"):
|
2108
|
-
lora_config_kwargs.pop("use_dora")
|
2109
|
-
|
2110
|
-
if "lora_bias" in lora_config_kwargs:
|
2111
|
-
if lora_config_kwargs["lora_bias"]:
|
2112
|
-
if is_peft_version("<=", "0.13.2"):
|
2113
|
-
raise ValueError(
|
2114
|
-
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
2115
|
-
)
|
2116
|
-
else:
|
2117
|
-
if is_peft_version("<=", "0.13.2"):
|
2118
|
-
lora_config_kwargs.pop("lora_bias")
|
2119
|
-
|
2120
|
-
lora_config = LoraConfig(**lora_config_kwargs)
|
2121
|
-
|
2122
|
-
# adapter_name
|
2123
|
-
if adapter_name is None:
|
2124
|
-
adapter_name = get_adapter_name(text_encoder)
|
2125
|
-
|
2126
|
-
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
2127
|
-
|
2128
|
-
# inject LoRA layers and load the state dict
|
2129
|
-
# in transformers we automatically check whether the adapter name is already in use or not
|
2130
|
-
text_encoder.load_adapter(
|
2131
|
-
adapter_name=adapter_name,
|
2132
|
-
adapter_state_dict=text_encoder_lora_state_dict,
|
2133
|
-
peft_config=lora_config,
|
2134
|
-
**peft_kwargs,
|
2135
|
-
)
|
2136
|
-
|
2137
|
-
# scale LoRA layers with `lora_scale`
|
2138
|
-
scale_lora_layers(text_encoder, weight=lora_scale)
|
2139
|
-
|
2140
|
-
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
2141
|
-
|
2142
|
-
# Offload back.
|
2143
|
-
if is_model_cpu_offload:
|
2144
|
-
_pipeline.enable_model_cpu_offload()
|
2145
|
-
elif is_sequential_cpu_offload:
|
2146
|
-
_pipeline.enable_sequential_cpu_offload()
|
|
2069
|
+
_load_lora_into_text_encoder(
|
2070
|
+
state_dict=state_dict,
|
2071
|
+
network_alphas=network_alphas,
|
2072
|
+
lora_scale=lora_scale,
|
2073
|
+
text_encoder=text_encoder,
|
2074
|
+
prefix=prefix,
|
2075
|
+
text_encoder_name=cls.text_encoder_name,
|
2076
|
+
adapter_name=adapter_name,
|
2077
|
+
_pipeline=_pipeline,
|
2078
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
2079
|
+
hotswap=hotswap,
|
2080
|
+
)
|
2148
2081
|
|
2149
2082
|
@classmethod
|
2150
2083
|
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
|
@@ -2203,7 +2136,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
 
     def fuse_lora(
         self,
-        components: List[str] = ["transformer"
+        components: List[str] = ["transformer"],
         lora_scale: float = 1.0,
         safe_fusing: bool = False,
         adapter_names: Optional[List[str]] = None,
@@ -2254,7 +2187,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2254
2187
|
)
|
2255
2188
|
|
2256
2189
|
super().fuse_lora(
|
2257
|
-
components=components,
|
2190
|
+
components=components,
|
2191
|
+
lora_scale=lora_scale,
|
2192
|
+
safe_fusing=safe_fusing,
|
2193
|
+
adapter_names=adapter_names,
|
2194
|
+
**kwargs,
|
2258
2195
|
)
|
2259
2196
|
|
2260
2197
|
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
@@ -2275,10 +2212,26 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2275
2212
|
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
|
2276
2213
|
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
|
2277
2214
|
|
2278
|
-
super().unfuse_lora(components=components)
|
2215
|
+
super().unfuse_lora(components=components, **kwargs)
|
2216
|
+
|
2217
|
+
# We override this here account for `_transformer_norm_layers` and `_overwritten_params`.
|
2218
|
+
def unload_lora_weights(self, reset_to_overwritten_params=False):
|
2219
|
+
"""
|
2220
|
+
Unloads the LoRA parameters.
|
2221
|
+
|
2222
|
+
Args:
|
2223
|
+
reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules
|
2224
|
+
to their original params. Refer to the [Flux
|
2225
|
+
documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.
|
2226
|
+
|
2227
|
+
Examples:
|
2279
2228
|
|
2280
|
-
|
2281
|
-
|
2229
|
+
```python
|
2230
|
+
>>> # Assuming `pipeline` is already loaded with the LoRA parameters.
|
2231
|
+
>>> pipeline.unload_lora_weights()
|
2232
|
+
>>> ...
|
2233
|
+
```
|
2234
|
+
"""
|
2282
2235
|
super().unload_lora_weights()
|
2283
2236
|
|
2284
2237
|
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
|
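A brief usage sketch for the `reset_to_overwritten_params` flag documented above; it assumes `pipe` is an already constructed Flux pipeline and the checkpoint name is a placeholder:

```py
# Assuming `pipe` is a FluxPipeline; the LoRA file name is illustrative.
pipe.load_lora_weights("control-lora.safetensors")   # may expand layers such as x_embedder
image = pipe("a prompt", num_inference_steps=28).images[0]

# Remove the LoRA and restore the expanded Linear layers (and config values) to their originals.
pipe.unload_lora_weights(reset_to_overwritten_params=True)
```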
@@ -2286,11 +2239,55 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2286
2239
|
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
|
2287
2240
|
transformer._transformer_norm_layers = None
|
2288
2241
|
|
2289
|
-
|
2290
|
-
|
2291
|
-
|
2292
|
-
|
2293
|
-
|
2242
|
+
if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
|
2243
|
+
overwritten_params = transformer._overwritten_params
|
2244
|
+
module_names = set()
|
2245
|
+
|
2246
|
+
for param_name in overwritten_params:
|
2247
|
+
if param_name.endswith(".weight"):
|
2248
|
+
module_names.add(param_name.replace(".weight", ""))
|
2249
|
+
|
2250
|
+
for name, module in transformer.named_modules():
|
2251
|
+
if isinstance(module, torch.nn.Linear) and name in module_names:
|
2252
|
+
module_weight = module.weight.data
|
2253
|
+
module_bias = module.bias.data if module.bias is not None else None
|
2254
|
+
bias = module_bias is not None
|
2255
|
+
|
2256
|
+
parent_module_name, _, current_module_name = name.rpartition(".")
|
2257
|
+
parent_module = transformer.get_submodule(parent_module_name)
|
2258
|
+
|
2259
|
+
current_param_weight = overwritten_params[f"{name}.weight"]
|
2260
|
+
in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
|
2261
|
+
with torch.device("meta"):
|
2262
|
+
original_module = torch.nn.Linear(
|
2263
|
+
in_features,
|
2264
|
+
out_features,
|
2265
|
+
bias=bias,
|
2266
|
+
dtype=module_weight.dtype,
|
2267
|
+
)
|
2268
|
+
|
2269
|
+
tmp_state_dict = {"weight": current_param_weight}
|
2270
|
+
if module_bias is not None:
|
2271
|
+
tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
|
2272
|
+
original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
|
2273
|
+
setattr(parent_module, current_module_name, original_module)
|
2274
|
+
|
2275
|
+
del tmp_state_dict
|
2276
|
+
|
2277
|
+
if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
|
2278
|
+
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
|
2279
|
+
new_value = int(current_param_weight.shape[1])
|
2280
|
+
old_value = getattr(transformer.config, attribute_name)
|
2281
|
+
setattr(transformer.config, attribute_name, new_value)
|
2282
|
+
logger.info(
|
2283
|
+
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
|
2284
|
+
)
|
2285
|
+
|
2286
|
+
@classmethod
|
2287
|
+
def _maybe_expand_transformer_param_shape_or_error_(
|
2288
|
+
cls,
|
2289
|
+
transformer: torch.nn.Module,
|
2290
|
+
lora_state_dict=None,
|
2294
2291
|
norm_state_dict=None,
|
2295
2292
|
prefix=None,
|
2296
2293
|
) -> bool:
|
@@ -2312,7 +2309,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2312
2309
|
|
2313
2310
|
# Expand transformer parameter shapes if they don't match lora
|
2314
2311
|
has_param_with_shape_update = False
|
2312
|
+
overwritten_params = {}
|
2313
|
+
|
2315
2314
|
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
|
2315
|
+
is_quantized = hasattr(transformer, "hf_quantizer")
|
2316
2316
|
for name, module in transformer.named_modules():
|
2317
2317
|
if isinstance(module, torch.nn.Linear):
|
2318
2318
|
module_weight = module.weight.data
|
@@ -2328,11 +2328,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 in_features = state_dict[lora_A_weight_name].shape[1]
                 out_features = state_dict[lora_B_weight_name].shape[0]
 
+                # Model maybe loaded with different quantization schemes which may flatten the params.
+                # `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models
+                # preserve weight shape.
+                module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
+
                 # This means there's no need for an expansion in the params, so we simply skip.
-                if tuple(
+                if tuple(module_weight_shape) == (out_features, in_features):
                     continue
 
-                module_out_features, module_in_features =
+                module_out_features, module_in_features = module_weight_shape
                 debug_message = ""
                 if in_features > module_in_features:
                     debug_message += (
@@ -2355,6 +2360,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2355
2360
|
parent_module_name, _, current_module_name = name.rpartition(".")
|
2356
2361
|
parent_module = transformer.get_submodule(parent_module_name)
|
2357
2362
|
|
2363
|
+
if is_quantized:
|
2364
|
+
module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
|
2365
|
+
|
2366
|
+
# TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
|
2358
2367
|
with torch.device("meta"):
|
2359
2368
|
expanded_module = torch.nn.Linear(
|
2360
2369
|
in_features, out_features, bias=bias, dtype=module_weight.dtype
|
@@ -2366,7 +2375,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2366
2375
|
new_weight = torch.zeros_like(
|
2367
2376
|
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
|
2368
2377
|
)
|
2369
|
-
slices = tuple(slice(0, dim) for dim in
|
2378
|
+
slices = tuple(slice(0, dim) for dim in module_weight_shape)
|
2370
2379
|
new_weight[slices] = module_weight
|
2371
2380
|
tmp_state_dict = {"weight": new_weight}
|
2372
2381
|
if module_bias is not None:
|
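The expansion above copies the original weight into the top-left block of a larger, zero-initialized Linear layer. A toy, self-contained sketch of that operation with made-up shapes:

```py
import torch

# Toy shapes for illustration: expand a Linear from 64 -> 128 input features.
original = torch.nn.Linear(64, 32, bias=False)
expanded = torch.nn.Linear(128, 32, bias=False)

new_weight = torch.zeros_like(expanded.weight.data)
slices = tuple(slice(0, dim) for dim in original.weight.shape)  # (out, in) of the old weight
new_weight[slices] = original.weight.data                       # old block copied, rest stays zero
expanded.load_state_dict({"weight": new_weight}, strict=True)

x = torch.randn(1, 128)
# The first 64 input features behave exactly like the original layer; the rest contribute zero.
assert torch.allclose(expanded(x), original(x[:, :64]))
```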
@@ -2386,6 +2395,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2386
2395
|
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
|
2387
2396
|
)
|
2388
2397
|
|
2398
|
+
# For `unload_lora_weights()`.
|
2399
|
+
# TODO: this could lead to more memory overhead if the number of overwritten params
|
2400
|
+
# are large. Should be revisited later and tackled through a `discard_original_layers` arg.
|
2401
|
+
overwritten_params[f"{current_module_name}.weight"] = module_weight
|
2402
|
+
if module_bias is not None:
|
2403
|
+
overwritten_params[f"{current_module_name}.bias"] = module_bias
|
2404
|
+
|
2405
|
+
if len(overwritten_params) > 0:
|
2406
|
+
transformer._overwritten_params = overwritten_params
|
2407
|
+
|
2389
2408
|
return has_param_with_shape_update
|
2390
2409
|
|
2391
2410
|
@classmethod
|
@@ -2410,18 +2429,23 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 continue
 
             base_param_name = (
-                f"{k.replace(prefix, '')}.base_layer.weight"
+                f"{k.replace(prefix, '')}.base_layer.weight"
+                if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
+                else f"{k.replace(prefix, '')}.weight"
             )
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
 
-
+            # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
+            base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
+
+            if base_module_shape[1] > lora_A_param.shape[1]:
                 shape = (lora_A_param.shape[0], base_weight_param.shape[1])
                 expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
                 expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                 lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
-            elif
+            elif base_module_shape[1] < lora_A_param.shape[1]:
                 raise NotImplementedError(
                     f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
                 )
@@ -2433,6 +2457,33 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2433
2457
|
|
2434
2458
|
return lora_state_dict
|
2435
2459
|
|
2460
|
+
@staticmethod
|
2461
|
+
def _calculate_module_shape(
|
2462
|
+
model: "torch.nn.Module",
|
2463
|
+
base_module: "torch.nn.Linear" = None,
|
2464
|
+
base_weight_param_name: str = None,
|
2465
|
+
) -> "torch.Size":
|
2466
|
+
def _get_weight_shape(weight: torch.Tensor):
|
2467
|
+
if weight.__class__.__name__ == "Params4bit":
|
2468
|
+
return weight.quant_state.shape
|
2469
|
+
elif weight.__class__.__name__ == "GGUFParameter":
|
2470
|
+
return weight.quant_shape
|
2471
|
+
else:
|
2472
|
+
return weight.shape
|
2473
|
+
|
2474
|
+
if base_module is not None:
|
2475
|
+
return _get_weight_shape(base_module.weight)
|
2476
|
+
elif base_weight_param_name is not None:
|
2477
|
+
if not base_weight_param_name.endswith(".weight"):
|
2478
|
+
raise ValueError(
|
2479
|
+
f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
|
2480
|
+
)
|
2481
|
+
module_path = base_weight_param_name.rsplit(".weight", 1)[0]
|
2482
|
+
submodule = get_submodule_by_name(model, module_path)
|
2483
|
+
return _get_weight_shape(submodule.weight)
|
2484
|
+
|
2485
|
+
raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
|
2486
|
+
|
2436
2487
|
|
2437
2488
|
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
|
2438
2489
|
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
|
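The `_calculate_module_shape` helper added above exists because quantized parameters do not report their logical shape via `.shape`. A standalone sketch of the same dispatch; the class-name checks and attribute names mirror the ones in the diff:

```py
import torch

def logical_weight_shape(weight: torch.Tensor) -> torch.Size:
    # bitsandbytes 4-bit params are stored flattened; the logical shape lives on quant_state.
    if weight.__class__.__name__ == "Params4bit":
        return weight.quant_state.shape
    # GGUF-quantized tensors carry their logical shape separately as well.
    if weight.__class__.__name__ == "GGUFParameter":
        return weight.quant_shape
    return weight.shape

print(logical_weight_shape(torch.nn.Linear(8, 4).weight))  # torch.Size([4, 8])
```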
@@ -2444,7 +2495,14 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
|
2444
2495
|
@classmethod
|
2445
2496
|
# Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
|
2446
2497
|
def load_lora_into_transformer(
|
2447
|
-
cls,
|
2498
|
+
cls,
|
2499
|
+
state_dict,
|
2500
|
+
network_alphas,
|
2501
|
+
transformer,
|
2502
|
+
adapter_name=None,
|
2503
|
+
_pipeline=None,
|
2504
|
+
low_cpu_mem_usage=False,
|
2505
|
+
hotswap: bool = False,
|
2448
2506
|
):
|
2449
2507
|
"""
|
2450
2508
|
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
@@ -2466,6 +2524,29 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
|
2466
2524
|
low_cpu_mem_usage (`bool`, *optional*):
|
2467
2525
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
2468
2526
|
weights.
|
2527
|
+
hotswap : (`bool`, *optional*)
|
2528
|
+
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
2529
|
+
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
2530
|
+
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
2531
|
+
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
2532
|
+
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
2533
|
+
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
2534
|
+
|
2535
|
+
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
2536
|
+
to call an additional method before loading the adapter:
|
2537
|
+
|
2538
|
+
```py
|
2539
|
+
pipeline = ... # load diffusers pipeline
|
2540
|
+
max_rank = ... # the highest rank among all LoRAs that you want to load
|
2541
|
+
# call *before* compiling and loading the LoRA adapter
|
2542
|
+
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
2543
|
+
pipeline.load_lora_weights(file_name)
|
2544
|
+
# optionally compile the model now
|
2545
|
+
```
|
2546
|
+
|
2547
|
+
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
2548
|
+
limitations to this technique, which are documented here:
|
2549
|
+
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
2469
2550
|
"""
|
2470
2551
|
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
2471
2552
|
raise ValueError(
|
@@ -2473,17 +2554,15 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
|
2473
2554
|
)
|
2474
2555
|
|
2475
2556
|
# Load the layers corresponding to transformer.
|
2476
|
-
|
2477
|
-
|
2478
|
-
|
2479
|
-
|
2480
|
-
|
2481
|
-
|
2482
|
-
|
2483
|
-
|
2484
|
-
|
2485
|
-
low_cpu_mem_usage=low_cpu_mem_usage,
|
2486
|
-
)
|
2557
|
+
logger.info(f"Loading {cls.transformer_name}.")
|
2558
|
+
transformer.load_lora_adapter(
|
2559
|
+
state_dict,
|
2560
|
+
network_alphas=network_alphas,
|
2561
|
+
adapter_name=adapter_name,
|
2562
|
+
_pipeline=_pipeline,
|
2563
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
2564
|
+
hotswap=hotswap,
|
2565
|
+
)
|
2487
2566
|
|
2488
2567
|
@classmethod
|
2489
2568
|
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
|
@@ -2497,6 +2576,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
|
2497
2576
|
adapter_name=None,
|
2498
2577
|
_pipeline=None,
|
2499
2578
|
low_cpu_mem_usage=False,
|
2579
|
+
hotswap: bool = False,
|
2500
2580
|
):
|
2501
2581
|
"""
|
2502
2582
|
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
@@ -2522,120 +2602,42 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
|
2522
2602
|
low_cpu_mem_usage (`bool`, *optional*):
|
2523
2603
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
2524
2604
|
weights.
|
2605
|
+
hotswap : (`bool`, *optional*)
|
2606
|
+
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
2607
|
+
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
2608
|
+
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
2609
|
+
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
2610
|
+
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
2611
|
+
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
2612
|
+
|
2613
|
+
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
2614
|
+
to call an additional method before loading the adapter:
|
2615
|
+
|
2616
|
+
```py
|
2617
|
+
pipeline = ... # load diffusers pipeline
|
2618
|
+
max_rank = ... # the highest rank among all LoRAs that you want to load
|
2619
|
+
# call *before* compiling and loading the LoRA adapter
|
2620
|
+
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
2621
|
+
pipeline.load_lora_weights(file_name)
|
2622
|
+
# optionally compile the model now
|
2623
|
+
```
|
2624
|
+
|
2625
|
+
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
2626
|
+
limitations to this technique, which are documented here:
|
2627
|
+
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
2525
2628
|
"""
|
2526
|
-
|
2527
|
-
|
2528
|
-
|
2529
|
-
|
2530
|
-
|
2531
|
-
|
2532
|
-
|
2533
|
-
|
2534
|
-
|
2535
|
-
|
2536
|
-
|
2537
|
-
|
2538
|
-
raise ValueError(
|
2539
|
-
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
2540
|
-
)
|
2541
|
-
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
2542
|
-
|
2543
|
-
from peft import LoraConfig
|
2544
|
-
|
2545
|
-
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
2546
|
-
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
2547
|
-
# their prefixes.
|
2548
|
-
keys = list(state_dict.keys())
|
2549
|
-
prefix = cls.text_encoder_name if prefix is None else prefix
|
2550
|
-
|
2551
|
-
# Safe prefix to check with.
|
2552
|
-
if any(cls.text_encoder_name in key for key in keys):
|
2553
|
-
# Load the layers corresponding to text encoder and make necessary adjustments.
|
2554
|
-
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
|
2555
|
-
text_encoder_lora_state_dict = {
|
2556
|
-
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
2557
|
-
}
|
2558
|
-
|
2559
|
-
if len(text_encoder_lora_state_dict) > 0:
|
2560
|
-
logger.info(f"Loading {prefix}.")
|
2561
|
-
rank = {}
|
2562
|
-
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
2563
|
-
|
2564
|
-
# convert state dict
|
2565
|
-
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
2566
|
-
|
2567
|
-
for name, _ in text_encoder_attn_modules(text_encoder):
|
2568
|
-
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
2569
|
-
rank_key = f"{name}.{module}.lora_B.weight"
|
2570
|
-
if rank_key not in text_encoder_lora_state_dict:
|
2571
|
-
continue
|
2572
|
-
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
2573
|
-
|
2574
|
-
for name, _ in text_encoder_mlp_modules(text_encoder):
|
2575
|
-
for module in ("fc1", "fc2"):
|
2576
|
-
rank_key = f"{name}.{module}.lora_B.weight"
|
2577
|
-
if rank_key not in text_encoder_lora_state_dict:
|
2578
|
-
continue
|
2579
|
-
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
2580
|
-
|
2581
|
-
if network_alphas is not None:
|
2582
|
-
alpha_keys = [
|
2583
|
-
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
|
2584
|
-
]
|
2585
|
-
network_alphas = {
|
2586
|
-
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
2587
|
-
}
|
2588
|
-
|
2589
|
-
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
2590
|
-
|
2591
|
-
if "use_dora" in lora_config_kwargs:
|
2592
|
-
if lora_config_kwargs["use_dora"]:
|
2593
|
-
if is_peft_version("<", "0.9.0"):
|
2594
|
-
raise ValueError(
|
2595
|
-
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
2596
|
-
)
|
2597
|
-
else:
|
2598
|
-
if is_peft_version("<", "0.9.0"):
|
2599
|
-
lora_config_kwargs.pop("use_dora")
|
2600
|
-
|
2601
|
-
if "lora_bias" in lora_config_kwargs:
|
2602
|
-
if lora_config_kwargs["lora_bias"]:
|
2603
|
-
if is_peft_version("<=", "0.13.2"):
|
2604
|
-
raise ValueError(
|
2605
|
-
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
2606
|
-
)
|
2607
|
-
else:
|
2608
|
-
if is_peft_version("<=", "0.13.2"):
|
2609
|
-
lora_config_kwargs.pop("lora_bias")
|
2610
|
-
|
2611
|
-
lora_config = LoraConfig(**lora_config_kwargs)
|
2612
|
-
|
2613
|
-
# adapter_name
|
2614
|
-
if adapter_name is None:
|
2615
|
-
adapter_name = get_adapter_name(text_encoder)
|
2616
|
-
|
2617
|
-
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
2618
|
-
|
2619
|
-
# inject LoRA layers and load the state dict
|
2620
|
-
# in transformers we automatically check whether the adapter name is already in use or not
|
2621
|
-
text_encoder.load_adapter(
|
2622
|
-
adapter_name=adapter_name,
|
2623
|
-
adapter_state_dict=text_encoder_lora_state_dict,
|
2624
|
-
peft_config=lora_config,
|
2625
|
-
**peft_kwargs,
|
2626
|
-
)
|
2627
|
-
|
2628
|
-
# scale LoRA layers with `lora_scale`
|
2629
|
-
scale_lora_layers(text_encoder, weight=lora_scale)
|
2630
|
-
|
2631
|
-
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
2632
|
-
|
2633
|
-
# Offload back.
|
2634
|
-
if is_model_cpu_offload:
|
2635
|
-
_pipeline.enable_model_cpu_offload()
|
2636
|
-
elif is_sequential_cpu_offload:
|
2637
|
-
_pipeline.enable_sequential_cpu_offload()
|
|
2629
|
+
_load_lora_into_text_encoder(
|
2630
|
+
state_dict=state_dict,
|
2631
|
+
network_alphas=network_alphas,
|
2632
|
+
lora_scale=lora_scale,
|
2633
|
+
text_encoder=text_encoder,
|
2634
|
+
prefix=prefix,
|
2635
|
+
text_encoder_name=cls.text_encoder_name,
|
2636
|
+
adapter_name=adapter_name,
|
2637
|
+
_pipeline=_pipeline,
|
2638
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
2639
|
+
hotswap=hotswap,
|
2640
|
+
)
|
2639
2641
|
|
2640
2642
|
@classmethod
|
2641
2643
|
def save_lora_weights(
|
@@ -2851,7 +2853,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2851
2853
|
@classmethod
|
2852
2854
|
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
|
2853
2855
|
def load_lora_into_transformer(
|
2854
|
-
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
2856
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
2855
2857
|
):
|
2856
2858
|
"""
|
2857
2859
|
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
@@ -2869,6 +2871,29 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2869
2871
|
low_cpu_mem_usage (`bool`, *optional*):
|
2870
2872
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
2871
2873
|
weights.
|
2874
|
+
hotswap : (`bool`, *optional*)
|
2875
|
+
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
2876
|
+
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
2877
|
+
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
2878
|
+
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
2879
|
+
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
2880
|
+
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
2881
|
+
|
2882
|
+
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
2883
|
+
to call an additional method before loading the adapter:
|
2884
|
+
|
2885
|
+
```py
|
2886
|
+
pipeline = ... # load diffusers pipeline
|
2887
|
+
max_rank = ... # the highest rank among all LoRAs that you want to load
|
2888
|
+
# call *before* compiling and loading the LoRA adapter
|
2889
|
+
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
2890
|
+
pipeline.load_lora_weights(file_name)
|
2891
|
+
# optionally compile the model now
|
2892
|
+
```
|
2893
|
+
|
2894
|
+
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
2895
|
+
limitations to this technique, which are documented here:
|
2896
|
+
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
2872
2897
|
"""
|
2873
2898
|
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
2874
2899
|
raise ValueError(
|
@@ -2883,6 +2908,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2883
2908
|
adapter_name=adapter_name,
|
2884
2909
|
_pipeline=_pipeline,
|
2885
2910
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
2911
|
+
hotswap=hotswap,
|
2886
2912
|
)
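The hotswap flow described in the docstring above, sketched end to end. The checkpoint names and the `target_rank` value are placeholders, the SDXL pipeline is chosen only as a familiar example, and whether `load_lora_weights` itself forwards the `hotswap` flag depends on the mixin, so read this as an illustration of the intended call order rather than a verified snippet:

```py
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Reserve capacity for the largest adapter *before* compiling.
pipe.enable_lora_hotswap(target_rank=64)
pipe.load_lora_weights("lora-one.safetensors", adapter_name="default_0")

# Optionally compile now; later swaps should not trigger recompilation.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")
_ = pipe("a pixel-art castle").images[0]

# Replace the weights of the already-loaded adapter in place.
pipe.load_lora_weights("lora-two.safetensors", adapter_name="default_0", hotswap=True)
_ = pipe("a pixel-art castle").images[0]
```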
|
2887
2913
|
|
2888
2914
|
@classmethod
|
@@ -2933,10 +2959,9 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2933
2959
|
safe_serialization=safe_serialization,
|
2934
2960
|
)
|
2935
2961
|
|
2936
|
-
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
|
2937
2962
|
def fuse_lora(
|
2938
2963
|
self,
|
2939
|
-
components: List[str] = ["transformer", "text_encoder"],
|
2964
|
+
components: List[str] = ["transformer"],
|
2940
2965
|
lora_scale: float = 1.0,
|
2941
2966
|
safe_fusing: bool = False,
|
2942
2967
|
adapter_names: Optional[List[str]] = None,
|
@@ -2974,11 +2999,14 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2974
2999
|
```
|
2975
3000
|
"""
|
2976
3001
|
super().fuse_lora(
|
2977
|
-
components=components,
|
3002
|
+
components=components,
|
3003
|
+
lora_scale=lora_scale,
|
3004
|
+
safe_fusing=safe_fusing,
|
3005
|
+
adapter_names=adapter_names,
|
3006
|
+
**kwargs,
|
2978
3007
|
)
|
2979
3008
|
|
2980
|
-
|
2981
|
-
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
3009
|
+
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
2982
3010
|
r"""
|
2983
3011
|
Reverses the effect of
|
2984
3012
|
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
@@ -2992,11 +3020,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2992
3020
|
Args:
|
2993
3021
|
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
2994
3022
|
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
2995
|
-
unfuse_text_encoder (`bool`, defaults to `True`):
|
2996
|
-
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
2997
|
-
LoRA parameters then it won't have any effect.
|
2998
3023
|
"""
|
2999
|
-
super().unfuse_lora(components=components)
|
3024
|
+
super().unfuse_lora(components=components, **kwargs)
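A short fuse/unfuse round trip against the mixin above; the LoRA file and adapter name are placeholders:

```py
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("my-cogvideox-lora.safetensors", adapter_name="my-lora")

# Fold the adapter into the transformer weights, slightly scaled down ...
pipe.fuse_lora(components=["transformer"], lora_scale=0.9, adapter_names=["my-lora"])
video = pipe("a panda playing guitar on a boat").frames[0]

# ... and restore the original, unfused weights afterwards.
pipe.unfuse_lora(components=["transformer"])
```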
|
3000
3025
|
|
3001
3026
|
|
3002
3027
|
class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
@@ -3159,7 +3184,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3159
3184
|
@classmethod
|
3160
3185
|
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
|
3161
3186
|
def load_lora_into_transformer(
|
3162
|
-
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
3187
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
3163
3188
|
):
|
3164
3189
|
"""
|
3165
3190
|
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
@@ -3177,6 +3202,29 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3177
3202
|
low_cpu_mem_usage (`bool`, *optional*):
|
3178
3203
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3179
3204
|
weights.
|
3205
|
+
hotswap : (`bool`, *optional*)
|
3206
|
+
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
3207
|
+
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
3208
|
+
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
3209
|
+
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
3210
|
+
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
3211
|
+
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
3212
|
+
|
3213
|
+
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
3214
|
+
to call an additional method before loading the adapter:
|
3215
|
+
|
3216
|
+
```py
|
3217
|
+
pipeline = ... # load diffusers pipeline
|
3218
|
+
max_rank = ... # the highest rank among all LoRAs that you want to load
|
3219
|
+
# call *before* compiling and loading the LoRA adapter
|
3220
|
+
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
3221
|
+
pipeline.load_lora_weights(file_name)
|
3222
|
+
# optionally compile the model now
|
3223
|
+
```
|
3224
|
+
|
3225
|
+
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
3226
|
+
limitations to this technique, which are documented here:
|
3227
|
+
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
3180
3228
|
"""
|
3181
3229
|
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3182
3230
|
raise ValueError(
|
@@ -3191,6 +3239,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3191
3239
|
adapter_name=adapter_name,
|
3192
3240
|
_pipeline=_pipeline,
|
3193
3241
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
3242
|
+
hotswap=hotswap,
|
3194
3243
|
)
|
3195
3244
|
|
3196
3245
|
@classmethod
|
@@ -3241,10 +3290,10 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3241
3290
|
safe_serialization=safe_serialization,
|
3242
3291
|
)
|
3243
3292
|
|
3244
|
-
# Copied from diffusers.loaders.lora_pipeline.
|
3293
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
3245
3294
|
def fuse_lora(
|
3246
3295
|
self,
|
3247
|
-
components: List[str] = ["transformer", "text_encoder"],
|
3296
|
+
components: List[str] = ["transformer"],
|
3248
3297
|
lora_scale: float = 1.0,
|
3249
3298
|
safe_fusing: bool = False,
|
3250
3299
|
adapter_names: Optional[List[str]] = None,
|
@@ -3282,11 +3331,15 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3282
3331
|
```
|
3283
3332
|
"""
|
3284
3333
|
super().fuse_lora(
|
3285
|
-
components=components,
|
3334
|
+
components=components,
|
3335
|
+
lora_scale=lora_scale,
|
3336
|
+
safe_fusing=safe_fusing,
|
3337
|
+
adapter_names=adapter_names,
|
3338
|
+
**kwargs,
|
3286
3339
|
)
|
3287
3340
|
|
3288
|
-
# Copied from diffusers.loaders.lora_pipeline.
|
3289
|
-
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
3341
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
3342
|
+
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
3290
3343
|
r"""
|
3291
3344
|
Reverses the effect of
|
3292
3345
|
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
@@ -3300,11 +3353,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3300
3353
|
Args:
|
3301
3354
|
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
3302
3355
|
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
3303
|
-
unfuse_text_encoder (`bool`, defaults to `True`):
|
3304
|
-
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
3305
|
-
LoRA parameters then it won't have any effect.
|
3306
3356
|
"""
|
3307
|
-
super().unfuse_lora(components=components)
|
3357
|
+
super().unfuse_lora(components=components, **kwargs)
|
3308
3358
|
|
3309
3359
|
|
3310
3360
|
class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
@@ -3467,7 +3517,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3467
3517
|
@classmethod
|
3468
3518
|
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
|
3469
3519
|
def load_lora_into_transformer(
|
3470
|
-
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
3520
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
3471
3521
|
):
|
3472
3522
|
"""
|
3473
3523
|
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
@@ -3485,6 +3535,29 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3485
3535
|
low_cpu_mem_usage (`bool`, *optional*):
|
3486
3536
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3487
3537
|
weights.
|
3538
|
+
hotswap : (`bool`, *optional*)
|
3539
|
+
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
3540
|
+
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
3541
|
+
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
3542
|
+
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
3543
|
+
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
3544
|
+
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
3545
|
+
|
3546
|
+
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
3547
|
+
to call an additional method before loading the adapter:
|
3548
|
+
|
3549
|
+
```py
|
3550
|
+
pipeline = ... # load diffusers pipeline
|
3551
|
+
max_rank = ... # the highest rank among all LoRAs that you want to load
|
3552
|
+
# call *before* compiling and loading the LoRA adapter
|
3553
|
+
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
3554
|
+
pipeline.load_lora_weights(file_name)
|
3555
|
+
# optionally compile the model now
|
3556
|
+
```
|
3557
|
+
|
3558
|
+
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
3559
|
+
limitations to this technique, which are documented here:
|
3560
|
+
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
3488
3561
|
"""
|
3489
3562
|
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3490
3563
|
raise ValueError(
|
@@ -3499,6 +3572,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3499
3572
|
adapter_name=adapter_name,
|
3500
3573
|
_pipeline=_pipeline,
|
3501
3574
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
3575
|
+
hotswap=hotswap,
|
3502
3576
|
)
|
3503
3577
|
|
3504
3578
|
@classmethod
|
@@ -3549,10 +3623,10 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3549
3623
|
safe_serialization=safe_serialization,
|
3550
3624
|
)
|
3551
3625
|
|
3552
|
-
# Copied from diffusers.loaders.lora_pipeline.
|
3626
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
3553
3627
|
def fuse_lora(
|
3554
3628
|
self,
|
3555
|
-
components: List[str] = ["transformer", "text_encoder"],
|
3629
|
+
components: List[str] = ["transformer"],
|
3556
3630
|
lora_scale: float = 1.0,
|
3557
3631
|
safe_fusing: bool = False,
|
3558
3632
|
adapter_names: Optional[List[str]] = None,
|
@@ -3590,11 +3664,15 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3590
3664
|
```
|
3591
3665
|
"""
|
3592
3666
|
super().fuse_lora(
|
3593
|
-
components=components,
|
3667
|
+
components=components,
|
3668
|
+
lora_scale=lora_scale,
|
3669
|
+
safe_fusing=safe_fusing,
|
3670
|
+
adapter_names=adapter_names,
|
3671
|
+
**kwargs,
|
3594
3672
|
)
|
3595
3673
|
|
3596
|
-
# Copied from diffusers.loaders.lora_pipeline.
|
3597
|
-
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
3674
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
3675
|
+
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
3598
3676
|
r"""
|
3599
3677
|
Reverses the effect of
|
3600
3678
|
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
@@ -3608,11 +3686,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3608
3686
|
Args:
|
3609
3687
|
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
3610
3688
|
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
3611
|
-
unfuse_text_encoder (`bool`, defaults to `True`):
|
3612
|
-
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
3613
|
-
LoRA parameters then it won't have any effect.
|
3614
3689
|
"""
|
3615
|
-
super().unfuse_lora(components=components)
|
3690
|
+
super().unfuse_lora(components=components, **kwargs)
|
3616
3691
|
|
3617
3692
|
|
3618
3693
|
class SanaLoraLoaderMixin(LoraBaseMixin):
|
@@ -3775,7 +3850,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
|
3775
3850
|
@classmethod
|
3776
3851
|
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
|
3777
3852
|
def load_lora_into_transformer(
|
3778
|
-
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
3853
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
3779
3854
|
):
|
3780
3855
|
"""
|
3781
3856
|
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
@@ -3793,6 +3868,29 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
|
3793
3868
|
low_cpu_mem_usage (`bool`, *optional*):
|
3794
3869
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3795
3870
|
weights.
|
3871
|
+
hotswap : (`bool`, *optional*)
|
3872
|
+
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
3873
|
+
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
3874
|
+
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
3875
|
+
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
3876
|
+
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
3877
|
+
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
3878
|
+
|
3879
|
+
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
3880
|
+
to call an additional method before loading the adapter:
|
3881
|
+
|
3882
|
+
```py
|
3883
|
+
pipeline = ... # load diffusers pipeline
|
3884
|
+
max_rank = ... # the highest rank among all LoRAs that you want to load
|
3885
|
+
# call *before* compiling and loading the LoRA adapter
|
3886
|
+
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
3887
|
+
pipeline.load_lora_weights(file_name)
|
3888
|
+
# optionally compile the model now
|
3889
|
+
```
|
3890
|
+
|
3891
|
+
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
3892
|
+
limitations to this technique, which are documented here:
|
3893
|
+
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
3796
3894
|
"""
|
3797
3895
|
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3798
3896
|
raise ValueError(
|
@@ -3807,6 +3905,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
|
3807
3905
|
adapter_name=adapter_name,
|
3808
3906
|
_pipeline=_pipeline,
|
3809
3907
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
3908
|
+
hotswap=hotswap,
|
3810
3909
|
)
|
3811
3910
|
|
3812
3911
|
@classmethod
|
@@ -3857,10 +3956,10 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
|
3857
3956
|
safe_serialization=safe_serialization,
|
3858
3957
|
)
|
3859
3958
|
|
3860
|
-
# Copied from diffusers.loaders.lora_pipeline.
|
3959
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
3861
3960
|
def fuse_lora(
|
3862
3961
|
self,
|
3863
|
-
components: List[str] = ["transformer", "text_encoder"],
|
3962
|
+
components: List[str] = ["transformer"],
|
3864
3963
|
lora_scale: float = 1.0,
|
3865
3964
|
safe_fusing: bool = False,
|
3866
3965
|
adapter_names: Optional[List[str]] = None,
|
@@ -3898,11 +3997,15 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
|
3898
3997
|
```
|
3899
3998
|
"""
|
3900
3999
|
super().fuse_lora(
|
3901
|
-
components=components,
|
4000
|
+
components=components,
|
4001
|
+
lora_scale=lora_scale,
|
4002
|
+
safe_fusing=safe_fusing,
|
4003
|
+
adapter_names=adapter_names,
|
4004
|
+
**kwargs,
|
3902
4005
|
)
|
3903
4006
|
|
3904
|
-
# Copied from diffusers.loaders.lora_pipeline.
|
3905
|
-
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
4007
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
4008
|
+
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
3906
4009
|
r"""
|
3907
4010
|
Reverses the effect of
|
3908
4011
|
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
@@ -3916,11 +4019,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
|
3916
4019
|
Args:
|
3917
4020
|
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
3918
4021
|
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
3919
|
-
unfuse_text_encoder (`bool`, defaults to `True`):
|
3920
|
-
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
3921
|
-
LoRA parameters then it won't have any effect.
|
3922
4022
|
"""
|
3923
|
-
super().unfuse_lora(components=components)
|
4023
|
+
super().unfuse_lora(components=components, **kwargs)
|
3924
4024
|
|
3925
4025
|
|
3926
4026
|
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
@@ -3933,7 +4033,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3933
4033
|
|
3934
4034
|
@classmethod
|
3935
4035
|
@validate_hf_hub_args
|
3936
|
-
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
3937
4036
|
def lora_state_dict(
|
3938
4037
|
cls,
|
3939
4038
|
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
@@ -3944,7 +4043,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3944
4043
|
|
3945
4044
|
<Tip warning={true}>
|
3946
4045
|
|
3947
|
-
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
4046
|
+
We support loading original format HunyuanVideo LoRA checkpoints.
|
3948
4047
|
|
3949
4048
|
This function is experimental and might change in the future.
|
3950
4049
|
|
@@ -4027,6 +4126,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
|
4027
4126
|
logger.warning(warn_msg)
|
4028
4127
|
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
4029
4128
|
|
4129
|
+
is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
|
4130
|
+
if is_original_hunyuan_video:
|
4131
|
+
state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
|
4132
|
+
|
4030
4133
|
return state_dict
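The new hunk above follows a sniff-then-convert pattern: a signature substring (`img_attn_qkv`) identifies the original HunyuanVideo layout before the converter runs. A toy illustration with invented keys and a stand-in converter (the real one is `_convert_hunyuan_video_lora_to_diffusers`):

```py
import torch

def convert_original_to_diffusers(sd):
    # Hypothetical rename, standing in for the real conversion logic.
    return {k.replace("img_attn_qkv", "attn.to_qkv"): v for k, v in sd.items()}

state_dict = {
    "double_blocks.0.img_attn_qkv.lora_A.weight": torch.zeros(4, 3072),
    "double_blocks.0.img_attn_qkv.lora_B.weight": torch.zeros(9216, 4),
}

if any("img_attn_qkv" in k for k in state_dict):
    state_dict = convert_original_to_diffusers(state_dict)
```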
|
4031
4134
|
|
4032
4135
|
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
@@ -4083,7 +4186,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
|
4083
4186
|
@classmethod
|
4084
4187
|
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
|
4085
4188
|
def load_lora_into_transformer(
|
4086
|
-
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
4189
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
4087
4190
|
):
|
4088
4191
|
"""
|
4089
4192
|
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
@@ -4101,6 +4204,29 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
|
4101
4204
|
low_cpu_mem_usage (`bool`, *optional*):
|
4102
4205
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
4103
4206
|
weights.
|
4207
|
+
hotswap : (`bool`, *optional*)
|
4208
|
+
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
4209
|
+
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
4210
|
+
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
4211
|
+
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
4212
|
+
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
4213
|
+
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
4214
|
+
|
4215
|
+
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
4216
|
+
to call an additional method before loading the adapter:
|
4217
|
+
|
4218
|
+
```py
|
4219
|
+
pipeline = ... # load diffusers pipeline
|
4220
|
+
max_rank = ... # the highest rank among all LoRAs that you want to load
|
4221
|
+
# call *before* compiling and loading the LoRA adapter
|
4222
|
+
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
4223
|
+
pipeline.load_lora_weights(file_name)
|
4224
|
+
# optionally compile the model now
|
4225
|
+
```
|
4226
|
+
|
4227
|
+
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
4228
|
+
limitations to this technique, which are documented here:
|
4229
|
+
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
4104
4230
|
"""
|
4105
4231
|
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
4106
4232
|
raise ValueError(
|
@@ -4115,6 +4241,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
|
4115
4241
|
adapter_name=adapter_name,
|
4116
4242
|
_pipeline=_pipeline,
|
4117
4243
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
4244
|
+
hotswap=hotswap,
|
4118
4245
|
)
|
4119
4246
|
|
4120
4247
|
@classmethod
|
@@ -4165,10 +4292,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
|
4165
4292
|
safe_serialization=safe_serialization,
|
4166
4293
|
)
|
4167
4294
|
|
4168
|
-
# Copied from diffusers.loaders.lora_pipeline.
|
4295
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
4169
4296
|
def fuse_lora(
|
4170
4297
|
self,
|
4171
|
-
components: List[str] = ["transformer", "text_encoder"],
|
4298
|
+
components: List[str] = ["transformer"],
|
4172
4299
|
lora_scale: float = 1.0,
|
4173
4300
|
safe_fusing: bool = False,
|
4174
4301
|
adapter_names: Optional[List[str]] = None,
|
@@ -4206,11 +4333,1049 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
|
4206
4333
|
```
|
4207
4334
|
"""
|
4208
4335
|
super().fuse_lora(
|
4209
|
-
components=components,
|
4336
|
+
components=components,
|
4337
|
+
lora_scale=lora_scale,
|
4338
|
+
safe_fusing=safe_fusing,
|
4339
|
+
adapter_names=adapter_names,
|
4340
|
+
**kwargs,
|
4210
4341
|
)
|
4211
4342
|
|
4212
|
-
# Copied from diffusers.loaders.lora_pipeline.
|
4213
|
-
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
4343
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
4344
|
+
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
4345
|
+
r"""
|
4346
|
+
Reverses the effect of
|
4347
|
+
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
4348
|
+
|
4349
|
+
<Tip warning={true}>
|
4350
|
+
|
4351
|
+
This is an experimental API.
|
4352
|
+
|
4353
|
+
</Tip>
|
4354
|
+
|
4355
|
+
Args:
|
4356
|
+
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
4357
|
+
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
4358
|
+
"""
|
4359
|
+
super().unfuse_lora(components=components, **kwargs)
|
4360
|
+
|
4361
|
+
|
4362
|
+
class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
4363
|
+
r"""
|
4364
|
+
Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`].
|
4365
|
+
"""
|
4366
|
+
|
4367
|
+
_lora_loadable_modules = ["transformer"]
|
4368
|
+
transformer_name = TRANSFORMER_NAME
|
4369
|
+
|
4370
|
+
@classmethod
|
4371
|
+
@validate_hf_hub_args
|
4372
|
+
def lora_state_dict(
|
4373
|
+
cls,
|
4374
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
4375
|
+
**kwargs,
|
4376
|
+
):
|
4377
|
+
r"""
|
4378
|
+
Return state dict for lora weights and the network alphas.
|
4379
|
+
|
4380
|
+
<Tip warning={true}>
|
4381
|
+
|
4382
|
+
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
4383
|
+
|
4384
|
+
This function is experimental and might change in the future.
|
4385
|
+
|
4386
|
+
</Tip>
|
4387
|
+
|
4388
|
+
Parameters:
|
4389
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
4390
|
+
Can be either:
|
4391
|
+
|
4392
|
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
4393
|
+
the Hub.
|
4394
|
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
4395
|
+
with [`ModelMixin.save_pretrained`].
|
4396
|
+
- A [torch state
|
4397
|
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
4398
|
+
|
4399
|
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
4400
|
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
4401
|
+
is not used.
|
4402
|
+
force_download (`bool`, *optional*, defaults to `False`):
|
4403
|
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
4404
|
+
cached versions if they exist.
|
4405
|
+
|
4406
|
+
proxies (`Dict[str, str]`, *optional*):
|
4407
|
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
4408
|
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
4409
|
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
4410
|
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
4411
|
+
won't be downloaded from the Hub.
|
4412
|
+
token (`str` or *bool*, *optional*):
|
4413
|
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
4414
|
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
4415
|
+
revision (`str`, *optional*, defaults to `"main"`):
|
4416
|
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
4417
|
+
allowed by Git.
|
4418
|
+
subfolder (`str`, *optional*, defaults to `""`):
|
4419
|
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
4420
|
+
|
4421
|
+
"""
|
4422
|
+
# Load the main state dict first which has the LoRA layers for either of
|
4423
|
+
# transformer and text encoder or both.
|
4424
|
+
cache_dir = kwargs.pop("cache_dir", None)
|
4425
|
+
force_download = kwargs.pop("force_download", False)
|
4426
|
+
proxies = kwargs.pop("proxies", None)
|
4427
|
+
local_files_only = kwargs.pop("local_files_only", None)
|
4428
|
+
token = kwargs.pop("token", None)
|
4429
|
+
revision = kwargs.pop("revision", None)
|
4430
|
+
subfolder = kwargs.pop("subfolder", None)
|
4431
|
+
weight_name = kwargs.pop("weight_name", None)
|
4432
|
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
4433
|
+
|
4434
|
+
allow_pickle = False
|
4435
|
+
if use_safetensors is None:
|
4436
|
+
use_safetensors = True
|
4437
|
+
allow_pickle = True
|
4438
|
+
|
4439
|
+
user_agent = {
|
4440
|
+
"file_type": "attn_procs_weights",
|
4441
|
+
"framework": "pytorch",
|
4442
|
+
}
|
4443
|
+
|
4444
|
+
state_dict = _fetch_state_dict(
|
4445
|
+
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
4446
|
+
weight_name=weight_name,
|
4447
|
+
use_safetensors=use_safetensors,
|
4448
|
+
local_files_only=local_files_only,
|
4449
|
+
cache_dir=cache_dir,
|
4450
|
+
force_download=force_download,
|
4451
|
+
proxies=proxies,
|
4452
|
+
token=token,
|
4453
|
+
revision=revision,
|
4454
|
+
subfolder=subfolder,
|
4455
|
+
user_agent=user_agent,
|
4456
|
+
allow_pickle=allow_pickle,
|
4457
|
+
)
|
4458
|
+
|
4459
|
+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
4460
|
+
if is_dora_scale_present:
|
4461
|
+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
4462
|
+
logger.warning(warn_msg)
|
4463
|
+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
4464
|
+
|
4465
|
+
# conversion.
|
4466
|
+
non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
|
4467
|
+
if non_diffusers:
|
4468
|
+
state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
|
4469
|
+
|
4470
|
+
return state_dict
|
4471
|
+
|
4472
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
4473
|
+
def load_lora_weights(
|
4474
|
+
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
4475
|
+
):
|
4476
|
+
"""
|
4477
|
+
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
4478
|
+
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
4479
|
+
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
4480
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
4481
|
+
dict is loaded into `self.transformer`.
|
4482
|
+
|
4483
|
+
Parameters:
|
4484
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
4485
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
4486
|
+
adapter_name (`str`, *optional*):
|
4487
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
4488
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
4489
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
4490
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
4491
|
+
weights.
|
4492
|
+
kwargs (`dict`, *optional*):
|
4493
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
4494
|
+
"""
|
4495
|
+
if not USE_PEFT_BACKEND:
|
4496
|
+
raise ValueError("PEFT backend is required for this method.")
|
4497
|
+
|
4498
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
4499
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
4500
|
+
raise ValueError(
|
4501
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
4502
|
+
)
|
4503
|
+
|
4504
|
+
# if a dict is passed, copy it instead of modifying it inplace
|
4505
|
+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
4506
|
+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
4507
|
+
|
4508
|
+
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
4509
|
+
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
4510
|
+
|
4511
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
4512
|
+
if not is_correct_format:
|
4513
|
+
raise ValueError("Invalid LoRA checkpoint.")
|
4514
|
+
|
4515
|
+
self.load_lora_into_transformer(
|
4516
|
+
state_dict,
|
4517
|
+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
4518
|
+
adapter_name=adapter_name,
|
4519
|
+
_pipeline=self,
|
4520
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
4521
|
+
)
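Because `load_lora_weights` also accepts a plain state dict, the two steps it performs (fetching via `lora_state_dict`, then injecting into the transformer) can be split, for example to inspect or filter keys first. The repo ids and the weight name below are placeholders:

```py
import torch
from diffusers import Lumina2Text2ImgPipeline

pipe = Lumina2Text2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)

# lora_state_dict also applies the diffusion_model.* conversion shown above.
state_dict = pipe.lora_state_dict("some-user/lumina2-lora", weight_name="pytorch_lora_weights.safetensors")
assert all("lora" in k for k in state_dict)  # the format check load_lora_weights enforces

pipe.load_lora_weights(state_dict, adapter_name="style")
```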
|
4522
|
+
|
4523
|
+
@classmethod
|
4524
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
|
4525
|
+
def load_lora_into_transformer(
|
4526
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
4527
|
+
):
|
4528
|
+
"""
|
4529
|
+
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
4530
|
+
|
4531
|
+
Parameters:
|
4532
|
+
state_dict (`dict`):
|
4533
|
+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
4534
|
+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
4535
|
+
encoder lora layers.
|
4536
|
+
transformer (`Lumina2Transformer2DModel`):
|
4537
|
+
The Transformer model to load the LoRA layers into.
|
4538
|
+
adapter_name (`str`, *optional*):
|
4539
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
4540
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
4541
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
4542
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
4543
|
+
weights.
|
4544
|
+
hotswap : (`bool`, *optional*)
|
4545
|
+
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
4546
|
+
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
4547
|
+
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
4548
|
+
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
4549
|
+
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
4550
|
+
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
4551
|
+
|
4552
|
+
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
4553
|
+
to call an additional method before loading the adapter:
|
4554
|
+
|
4555
|
+
```py
|
4556
|
+
pipeline = ... # load diffusers pipeline
|
4557
|
+
max_rank = ... # the highest rank among all LoRAs that you want to load
|
4558
|
+
# call *before* compiling and loading the LoRA adapter
|
4559
|
+
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
4560
|
+
pipeline.load_lora_weights(file_name)
|
4561
|
+
# optionally compile the model now
|
4562
|
+
```
|
4563
|
+
|
4564
|
+
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
4565
|
+
limitations to this technique, which are documented here:
|
4566
|
+
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
4567
|
+
"""
|
4568
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
4569
|
+
raise ValueError(
|
4570
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
4571
|
+
)
|
4572
|
+
|
4573
|
+
# Load the layers corresponding to transformer.
|
4574
|
+
logger.info(f"Loading {cls.transformer_name}.")
|
4575
|
+
transformer.load_lora_adapter(
|
4576
|
+
state_dict,
|
4577
|
+
network_alphas=None,
|
4578
|
+
adapter_name=adapter_name,
|
4579
|
+
_pipeline=_pipeline,
|
4580
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
4581
|
+
hotswap=hotswap,
|
4582
|
+
)
|
4583
|
+
|
4584
|
+
@classmethod
|
4585
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
4586
|
+
def save_lora_weights(
|
4587
|
+
cls,
|
4588
|
+
save_directory: Union[str, os.PathLike],
|
4589
|
+
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
4590
|
+
is_main_process: bool = True,
|
4591
|
+
weight_name: str = None,
|
4592
|
+
save_function: Callable = None,
|
4593
|
+
safe_serialization: bool = True,
|
4594
|
+
):
|
4595
|
+
r"""
|
4596
|
+
Save the LoRA parameters corresponding to the UNet and text encoder.
|
4597
|
+
|
4598
|
+
Arguments:
|
4599
|
+
save_directory (`str` or `os.PathLike`):
|
4600
|
+
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
4601
|
+
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
4602
|
+
State dict of the LoRA layers corresponding to the `transformer`.
|
4603
|
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
4604
|
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
4605
|
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
4606
|
+
process to avoid race conditions.
|
4607
|
+
save_function (`Callable`):
|
4608
|
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
4609
|
+
replace `torch.save` with another method. Can be configured with the environment variable
|
4610
|
+
`DIFFUSERS_SAVE_MODE`.
|
4611
|
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
4612
|
+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
4613
|
+
"""
|
4614
|
+
state_dict = {}
|
4615
|
+
|
4616
|
+
if not transformer_lora_layers:
|
4617
|
+
raise ValueError("You must pass `transformer_lora_layers`.")
|
4618
|
+
|
4619
|
+
if transformer_lora_layers:
|
4620
|
+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
4621
|
+
|
4622
|
+
# Save the model
|
4623
|
+
cls.write_lora_layers(
|
4624
|
+
state_dict=state_dict,
|
4625
|
+
save_directory=save_directory,
|
4626
|
+
is_main_process=is_main_process,
|
4627
|
+
weight_name=weight_name,
|
4628
|
+
save_function=save_function,
|
4629
|
+
safe_serialization=safe_serialization,
|
4630
|
+
)
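A minimal sketch of how `transformer_lora_layers` is typically produced for `save_lora_weights`, assuming a PEFT LoRA adapter was added to the transformer; the repo id, `target_modules`, and output directory are illustrative assumptions:

```py
import torch
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from diffusers import Lumina2Text2ImgPipeline

pipe = Lumina2Text2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)
pipe.transformer.add_adapter(LoraConfig(r=16, lora_alpha=16, target_modules=["to_q", "to_k", "to_v"]))

# ... a training loop would update the adapter parameters here ...

pipe.save_lora_weights(
    save_directory="./lumina2-lora",
    transformer_lora_layers=get_peft_model_state_dict(pipe.transformer),
)
```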
|
4631
|
+
|
4632
|
+
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
|
4633
|
+
def fuse_lora(
|
4634
|
+
self,
|
4635
|
+
components: List[str] = ["transformer"],
|
4636
|
+
lora_scale: float = 1.0,
|
4637
|
+
safe_fusing: bool = False,
|
4638
|
+
adapter_names: Optional[List[str]] = None,
|
4639
|
+
**kwargs,
|
4640
|
+
):
|
4641
|
+
r"""
|
4642
|
+
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
4643
|
+
|
4644
|
+
<Tip warning={true}>
|
4645
|
+
|
4646
|
+
This is an experimental API.
|
4647
|
+
|
4648
|
+
</Tip>
|
4649
|
+
|
4650
|
+
Args:
|
4651
|
+
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
4652
|
+
lora_scale (`float`, defaults to 1.0):
|
4653
|
+
Controls how much to influence the outputs with the LoRA parameters.
|
4654
|
+
safe_fusing (`bool`, defaults to `False`):
|
4655
|
+
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
4656
|
+
adapter_names (`List[str]`, *optional*):
|
4657
|
+
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
4658
|
+
|
4659
|
+
Example:
|
4660
|
+
|
4661
|
+
```py
|
4662
|
+
from diffusers import DiffusionPipeline
|
4663
|
+
import torch
|
4664
|
+
|
4665
|
+
pipeline = DiffusionPipeline.from_pretrained(
|
4666
|
+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
4667
|
+
).to("cuda")
|
4668
|
+
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
4669
|
+
pipeline.fuse_lora(lora_scale=0.7)
|
4670
|
+
```
|
4671
|
+
"""
|
4672
|
+
super().fuse_lora(
|
4673
|
+
components=components,
|
4674
|
+
lora_scale=lora_scale,
|
4675
|
+
safe_fusing=safe_fusing,
|
4676
|
+
adapter_names=adapter_names,
|
4677
|
+
**kwargs,
|
4678
|
+
)
|
4679
|
+
|
4680
|
+
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
|
4681
|
+
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
4682
|
+
r"""
|
4683
|
+
Reverses the effect of
|
4684
|
+
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
4685
|
+
|
4686
|
+
<Tip warning={true}>
|
4687
|
+
|
4688
|
+
This is an experimental API.
|
4689
|
+
|
4690
|
+
</Tip>
|
4691
|
+
|
4692
|
+
Args:
|
4693
|
+
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
4694
|
+
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
4695
|
+
"""
|
4696
|
+
super().unfuse_lora(components=components, **kwargs)
|
4697
|
+
|
4698
|
+
|
4699
|
+
class WanLoraLoaderMixin(LoraBaseMixin):
|
4700
|
+
r"""
|
4701
|
+
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`].
|
4702
|
+
"""
|
4703
|
+
|
4704
|
+
_lora_loadable_modules = ["transformer"]
|
4705
|
+
transformer_name = TRANSFORMER_NAME
|
4706
|
+
|
4707
|
+
@classmethod
|
4708
|
+
@validate_hf_hub_args
|
4709
|
+
def lora_state_dict(
|
4710
|
+
cls,
|
4711
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
4712
|
+
**kwargs,
|
4713
|
+
):
|
4714
|
+
r"""
|
4715
|
+
Return state dict for lora weights and the network alphas.
|
4716
|
+
|
4717
|
+
<Tip warning={true}>
|
4718
|
+
|
4719
|
+
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
4720
|
+
|
4721
|
+
This function is experimental and might change in the future.
|
4722
|
+
|
4723
|
+
</Tip>
|
4724
|
+
|
4725
|
+
Parameters:
|
4726
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
4727
|
+
Can be either:
|
4728
|
+
|
4729
|
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
4730
|
+
the Hub.
|
4731
|
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
4732
|
+
with [`ModelMixin.save_pretrained`].
|
4733
|
+
- A [torch state
|
4734
|
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
4735
|
+
|
4736
|
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
4737
|
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
4738
|
+
is not used.
|
4739
|
+
force_download (`bool`, *optional*, defaults to `False`):
|
4740
|
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
4741
|
+
cached versions if they exist.
|
4742
|
+
|
4743
|
+
proxies (`Dict[str, str]`, *optional*):
|
4744
|
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
4745
|
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
4746
|
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
4747
|
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
4748
|
+
won't be downloaded from the Hub.
|
4749
|
+
token (`str` or *bool*, *optional*):
|
4750
|
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
4751
|
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
4752
|
+
revision (`str`, *optional*, defaults to `"main"`):
|
4753
|
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
4754
|
+
allowed by Git.
|
4755
|
+
subfolder (`str`, *optional*, defaults to `""`):
|
4756
|
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
4757
|
+
|
4758
|
+
"""
|
4759
|
+
# Load the main state dict first which has the LoRA layers for either of
|
4760
|
+
# transformer and text encoder or both.
|
4761
|
+
cache_dir = kwargs.pop("cache_dir", None)
|
4762
|
+
force_download = kwargs.pop("force_download", False)
|
4763
|
+
proxies = kwargs.pop("proxies", None)
|
4764
|
+
local_files_only = kwargs.pop("local_files_only", None)
|
4765
|
+
token = kwargs.pop("token", None)
|
4766
|
+
revision = kwargs.pop("revision", None)
|
4767
|
+
subfolder = kwargs.pop("subfolder", None)
|
4768
|
+
weight_name = kwargs.pop("weight_name", None)
|
4769
|
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
4770
|
+
|
4771
|
+
allow_pickle = False
|
4772
|
+
if use_safetensors is None:
|
4773
|
+
use_safetensors = True
|
4774
|
+
allow_pickle = True
|
4775
|
+
|
4776
|
+
user_agent = {
|
4777
|
+
"file_type": "attn_procs_weights",
|
4778
|
+
"framework": "pytorch",
|
4779
|
+
}
|
4780
|
+
|
4781
|
+
state_dict = _fetch_state_dict(
|
4782
|
+
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
4783
|
+
weight_name=weight_name,
|
4784
|
+
use_safetensors=use_safetensors,
|
4785
|
+
local_files_only=local_files_only,
|
4786
|
+
cache_dir=cache_dir,
|
4787
|
+
force_download=force_download,
|
4788
|
+
proxies=proxies,
|
4789
|
+
token=token,
|
4790
|
+
revision=revision,
|
4791
|
+
subfolder=subfolder,
|
4792
|
+
user_agent=user_agent,
|
4793
|
+
allow_pickle=allow_pickle,
|
4794
|
+
)
|
4795
|
+
if any(k.startswith("diffusion_model.") for k in state_dict):
|
4796
|
+
state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
|
4797
|
+
|
4798
|
+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
4799
|
+
if is_dora_scale_present:
|
4800
|
+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
4801
|
+
logger.warning(warn_msg)
|
4802
|
+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
4803
|
+
|
4804
|
+
return state_dict
|
4805
|
+
|
4806
|
+
@classmethod
|
4807
|
+
def _maybe_expand_t2v_lora_for_i2v(
|
4808
|
+
cls,
|
4809
|
+
transformer: torch.nn.Module,
|
4810
|
+
state_dict,
|
4811
|
+
):
|
4812
|
+
if transformer.config.image_dim is None:
|
4813
|
+
return state_dict
|
4814
|
+
|
4815
|
+
if any(k.startswith("transformer.blocks.") for k in state_dict):
|
4816
|
+
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
|
4817
|
+
is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
|
4818
|
+
|
4819
|
+
if is_i2v_lora:
|
4820
|
+
return state_dict
|
4821
|
+
|
4822
|
+
for i in range(num_blocks):
|
4823
|
+
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
|
4824
|
+
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
|
4825
|
+
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
|
4826
|
+
)
|
4827
|
+
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
|
4828
|
+
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
|
4829
|
+
)
|
4830
|
+
|
4831
|
+
return state_dict
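Why the zero-filled expansion above is safe: a LoRA update is `lora_B @ lora_A`, so zero matrices for the missing `add_k_proj`/`add_v_proj` entries contribute nothing and only make a T2V checkpoint's key set match the I2V transformer. Illustrative shapes:

```py
import torch

lora_A = torch.zeros(8, 1536)   # (rank, in_features)
lora_B = torch.zeros(1536, 8)   # (out_features, rank)
assert torch.count_nonzero(lora_B @ lora_A) == 0  # the expanded layers stay a no-op
```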
|
4832
|
+
|
4833
|
+
def load_lora_weights(
|
4834
|
+
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
4835
|
+
):
|
4836
|
+
"""
|
4837
|
+
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
4838
|
+
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
4839
|
+
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
4840
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
4841
|
+
dict is loaded into `self.transformer`.
|
4842
|
+
|
4843
|
+
Parameters:
|
4844
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
4845
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
4846
|
+
adapter_name (`str`, *optional*):
|
4847
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
4848
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
4849
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
4850
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
4851
|
+
weights.
|
4852
|
+
kwargs (`dict`, *optional*):
|
4853
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
4854
|
+
"""
|
4855
|
+
if not USE_PEFT_BACKEND:
|
4856
|
+
raise ValueError("PEFT backend is required for this method.")
|
4857
|
+
|
4858
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
4859
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
4860
|
+
raise ValueError(
|
4861
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
4862
|
+
)
|
4863
|
+
|
4864
|
+
# if a dict is passed, copy it instead of modifying it inplace
|
4865
|
+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
4866
|
+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
4867
|
+
|
4868
|
+
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
4869
|
+
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
4870
|
+
# convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
|
4871
|
+
state_dict = self._maybe_expand_t2v_lora_for_i2v(
|
4872
|
+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
4873
|
+
state_dict=state_dict,
|
4874
|
+
)
|
4875
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
4876
|
+
if not is_correct_format:
|
4877
|
+
raise ValueError("Invalid LoRA checkpoint.")
|
4878
|
+
|
4879
|
+
self.load_lora_into_transformer(
|
4880
|
+
state_dict,
|
4881
|
+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
4882
|
+
adapter_name=adapter_name,
|
4883
|
+
_pipeline=self,
|
4884
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
4885
|
+
)
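Put together, this lets a text-to-video Wan LoRA load into the image-to-video pipeline unchanged; the repo ids and weight name below are placeholders:

```py
import torch
from diffusers import WanImageToVideoPipeline

pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
)
# Missing image-attention projections are zero-expanded by _maybe_expand_t2v_lora_for_i2v.
pipe.load_lora_weights(
    "some-user/wan-t2v-lora", weight_name="pytorch_lora_weights.safetensors", adapter_name="motion"
)
```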
|
4886
|
+
|
4887
|
+
@classmethod
|
4888
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
|
4889
|
+
def load_lora_into_transformer(
|
4890
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
4891
|
+
):
|
4892
|
+
"""
|
4893
|
+
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
4894
|
+
|
4895
|
+
Parameters:
|
4896
|
+
state_dict (`dict`):
|
4897
|
+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
4898
|
+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
4899
|
+
encoder lora layers.
|
4900
|
+
transformer (`WanTransformer3DModel`):
|
4901
|
+
The Transformer model to load the LoRA layers into.
|
4902
|
+
adapter_name (`str`, *optional*):
|
4903
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
4904
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
4905
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
4906
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
4907
|
+
weights.
|
4908
|
+
hotswap : (`bool`, *optional*)
|
4909
|
+
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
4910
|
+
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
4911
|
+
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
4912
|
+
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
4913
|
+
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
4914
|
+
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
4915
|
+
|
4916
|
+
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
4917
|
+
to call an additional method before loading the adapter:
|
4918
|
+
|
4919
|
+
```py
|
4920
|
+
pipeline = ... # load diffusers pipeline
|
4921
|
+
max_rank = ... # the highest rank among all LoRAs that you want to load
|
4922
|
+
# call *before* compiling and loading the LoRA adapter
|
4923
|
+
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
4924
|
+
pipeline.load_lora_weights(file_name)
|
4925
|
+
# optionally compile the model now
|
4926
|
+
```
|
4927
|
+
|
4928
|
+
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
4929
|
+
limitations to this technique, which are documented here:
|
4930
|
+
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
4931
|
+
"""
|
4932
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
4933
|
+
raise ValueError(
|
4934
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
4935
|
+
)
|
4936
|
+
|
4937
|
+
# Load the layers corresponding to transformer.
|
4938
|
+
logger.info(f"Loading {cls.transformer_name}.")
|
4939
|
+
transformer.load_lora_adapter(
|
4940
|
+
state_dict,
|
4941
|
+
network_alphas=None,
|
4942
|
+
adapter_name=adapter_name,
|
4943
|
+
_pipeline=_pipeline,
|
4944
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
4945
|
+
hotswap=hotswap,
|
4946
|
+
)
|
4947
|
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the UNet and text encoder.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
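`save_lora_weights` expects an already-flattened mapping of LoRA parameters for the transformer. A minimal sketch of one common way to produce that mapping from a PEFT-injected transformer; `pipe`, the output directory, and the assumption that the adapter was added through PEFT are all illustrative, not part of this diff.

```py
from peft.utils import get_peft_model_state_dict

# Collect the LoRA parameters that were injected into the transformer (assumes a PEFT-backed adapter).
transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)

# Write them as a single safetensors file under a placeholder directory.
pipe.save_lora_weights(
    save_directory="./my-wan-lora",
    transformer_lora_layers=transformer_lora_layers,
    safe_serialization=True,
)
```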
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+        """
+        super().unfuse_lora(components=components, **kwargs)
+
+
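Since `fuse_lora` and `unfuse_lora` are symmetric, a typical pattern is to fuse one named adapter for a batch of inference calls and then restore the original weights. A short sketch, assuming an adapter named "pixel" was loaded earlier on `pipe`; the prompt is illustrative.

```py
# Fuse only the "pixel" adapter into the transformer weights, scaled down to 0.7.
pipe.fuse_lora(components=["transformer"], lora_scale=0.7, adapter_names=["pixel"])
output = pipe(prompt="pixel art of a castle at dusk")

# Restore the original, unfused transformer weights afterwards.
pipe.unfuse_lora(components=["transformer"])
```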
+class CogView4LoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        return state_dict
+
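Because `lora_state_dict` only fetches and filters the checkpoint, it is also handy for inspecting a file before loading it into a pipeline. A small sketch; the repo id and weight name are hypothetical.

```py
# Fetch the raw LoRA state dict without touching a pipeline (placeholder repo id / file name).
state_dict = CogView4LoraLoaderMixin.lora_state_dict(
    "user/cogview4-style-lora", weight_name="pytorch_lora_weights.safetensors"
)

# `load_lora_weights` later requires every key to contain "lora"; DoRA scales were already filtered out above.
assert all("lora" in key for key in state_dict)
print(f"{len(state_dict)} tensors, e.g. {next(iter(state_dict))}")
```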
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
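Putting the pieces above together, loading a LoRA into a CogView4 pipeline looks roughly like the sketch below. The LoRA repo id is a placeholder, and the base checkpoint id is the commonly published CogView4 release (verify against the model card before use).

```py
import torch
from diffusers import CogView4Pipeline

# Base model id as commonly published for CogView4; the LoRA repo is a placeholder.
pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("user/cogview4-style-lora", adapter_name="style")

image = pipe(prompt="a watercolor painting of a lighthouse at dawn").images[0]
image.save("lighthouse.png")
```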
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`CogView4Transformer2DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+                to call an additional method before loading the adapter:
+
+                ```py
+                pipeline = ... # load diffusers pipeline
+                max_rank = ... # the highest rank among all LoRAs that you want to load
+                # call *before* compiling and loading the LoRA adapter
+                pipeline.enable_lora_hotswap(target_rank=max_rank)
+                pipeline.load_lora_weights(file_name)
+                # optionally compile the model now
+                ```
+
+                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+                limitations to this technique, which are documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the UNet and text encoder.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
         [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -4224,11 +5389,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
             unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
-            unfuse_text_encoder (`bool`, defaults to `True`):
-                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
-                LoRA parameters then it won't have any effect.
         """
-        super().unfuse_lora(components=components)
+        super().unfuse_lora(components=components, **kwargs)


 class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):