diffusers 0.27.1__py3-none-any.whl → 0.32.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +233 -6
- diffusers/callbacks.py +209 -0
- diffusers/commands/env.py +102 -6
- diffusers/configuration_utils.py +45 -16
- diffusers/dependency_versions_table.py +4 -3
- diffusers/image_processor.py +434 -110
- diffusers/loaders/__init__.py +42 -9
- diffusers/loaders/ip_adapter.py +626 -36
- diffusers/loaders/lora_base.py +900 -0
- diffusers/loaders/lora_conversion_utils.py +991 -125
- diffusers/loaders/lora_pipeline.py +3812 -0
- diffusers/loaders/peft.py +571 -7
- diffusers/loaders/single_file.py +405 -173
- diffusers/loaders/single_file_model.py +385 -0
- diffusers/loaders/single_file_utils.py +1783 -713
- diffusers/loaders/textual_inversion.py +41 -23
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +464 -540
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +76 -7
- diffusers/models/activations.py +65 -10
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +605 -18
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +4304 -687
- diffusers/models/autoencoders/__init__.py +8 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +110 -28
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
- diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
- diffusers/models/autoencoders/vae.py +41 -29
- diffusers/models/autoencoders/vq_model.py +182 -0
- diffusers/models/controlnet.py +47 -800
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +68 -0
- diffusers/models/controlnet_sparsectrl.py +116 -0
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/controlnets/controlnet_xs.py +1946 -0
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/downsampling.py +85 -18
- diffusers/models/embeddings.py +1856 -158
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +480 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +2 -7
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +611 -146
- diffusers/models/normalization.py +361 -20
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformers/__init__.py +16 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/dual_transformer_2d.py +9 -8
- diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +445 -0
- diffusers/models/transformers/prior_transformer.py +13 -13
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +297 -187
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +593 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +461 -0
- diffusers/models/transformers/transformer_temporal.py +21 -19
- diffusers/models/unets/unet_1d.py +8 -8
- diffusers/models/unets/unet_1d_blocks.py +31 -31
- diffusers/models/unets/unet_2d.py +17 -10
- diffusers/models/unets/unet_2d_blocks.py +225 -149
- diffusers/models/unets/unet_2d_condition.py +41 -40
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +192 -1057
- diffusers/models/unets/unet_3d_condition.py +22 -27
- diffusers/models/unets/unet_i2vgen_xl.py +22 -18
- diffusers/models/unets/unet_kandinsky3.py +2 -2
- diffusers/models/unets/unet_motion_model.py +1413 -89
- diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
- diffusers/models/unets/unet_stable_cascade.py +19 -18
- diffusers/models/unets/uvit_2d.py +2 -2
- diffusers/models/upsampling.py +95 -26
- diffusers/models/vq_model.py +12 -164
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +202 -3
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +8 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
- diffusers/pipelines/auto_pipeline.py +196 -28
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/cogvideo/__init__.py +54 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
- diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
- diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
- diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +7 -4
- diffusers/pipelines/flux/__init__.py +69 -0
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +957 -0
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +37 -0
- diffusers/pipelines/free_init_utils.py +41 -38
- diffusers/pipelines/free_noise_utils.py +596 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +338 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/pag/__init__.py +80 -0
- diffusers/pipelines/pag/pag_utils.py +243 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +74 -164
- diffusers/pipelines/pipeline_flax_utils.py +5 -10
- diffusers/pipelines/pipeline_loading_utils.py +515 -53
- diffusers/pipelines/pipeline_utils.py +411 -222
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
- diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
- diffusers/schedulers/__init__.py +12 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +27 -26
- diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
- diffusers/schedulers/scheduling_ddpm.py +27 -30
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +150 -50
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
- diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
- diffusers/schedulers/scheduling_edm_euler.py +62 -39
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
- diffusers/schedulers/scheduling_euler_discrete.py +255 -74
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
- diffusers/schedulers/scheduling_heun_discrete.py +174 -46
- diffusers/schedulers/scheduling_ipndm.py +9 -9
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +23 -29
- diffusers/schedulers/scheduling_lms_discrete.py +105 -28
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +21 -21
- diffusers/schedulers/scheduling_sasolver.py +157 -60
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +41 -36
- diffusers/schedulers/scheduling_unclip.py +19 -16
- diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
- diffusers/schedulers/scheduling_utils.py +12 -5
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +214 -30
- diffusers/utils/__init__.py +17 -1
- diffusers/utils/constants.py +3 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +592 -7
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
- diffusers/utils/dynamic_modules_utils.py +34 -29
- diffusers/utils/export_utils.py +50 -6
- diffusers/utils/hub_utils.py +131 -17
- diffusers/utils/import_utils.py +210 -8
- diffusers/utils/loading_utils.py +118 -5
- diffusers/utils/logging.py +4 -2
- diffusers/utils/peft_utils.py +37 -7
- diffusers/utils/state_dict_utils.py +13 -2
- diffusers/utils/testing_utils.py +193 -11
- diffusers/utils/torch_utils.py +4 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
- diffusers-0.32.2.dist-info/RECORD +550 -0
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1349
- diffusers/models/prior_transformer.py +0 -12
- diffusers/models/t5_film_transformer.py +0 -70
- diffusers/models/transformer_2d.py +0 -25
- diffusers/models/transformer_temporal.py +0 -34
- diffusers/models/unet_1d.py +0 -26
- diffusers/models/unet_1d_blocks.py +0 -203
- diffusers/models/unet_2d.py +0 -27
- diffusers/models/unet_2d_blocks.py +0 -375
- diffusers/models/unet_2d_condition.py +0 -25
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,3812 @@
|
|
1
|
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import os
|
16
|
+
from typing import Callable, Dict, List, Optional, Union
|
17
|
+
|
18
|
+
import torch
|
19
|
+
from huggingface_hub.utils import validate_hf_hub_args
|
20
|
+
|
21
|
+
from ..utils import (
|
22
|
+
USE_PEFT_BACKEND,
|
23
|
+
deprecate,
|
24
|
+
get_submodule_by_name,
|
25
|
+
is_peft_available,
|
26
|
+
is_peft_version,
|
27
|
+
is_torch_version,
|
28
|
+
is_transformers_available,
|
29
|
+
is_transformers_version,
|
30
|
+
logging,
|
31
|
+
)
|
32
|
+
from .lora_base import ( # noqa
|
33
|
+
LORA_WEIGHT_NAME,
|
34
|
+
LORA_WEIGHT_NAME_SAFE,
|
35
|
+
LoraBaseMixin,
|
36
|
+
_fetch_state_dict,
|
37
|
+
_load_lora_into_text_encoder,
|
38
|
+
)
|
39
|
+
from .lora_conversion_utils import (
|
40
|
+
_convert_bfl_flux_control_lora_to_diffusers,
|
41
|
+
_convert_hunyuan_video_lora_to_diffusers,
|
42
|
+
_convert_kohya_flux_lora_to_diffusers,
|
43
|
+
_convert_non_diffusers_lora_to_diffusers,
|
44
|
+
_convert_xlabs_flux_lora_to_diffusers,
|
45
|
+
_maybe_map_sgm_blocks_to_diffusers,
|
46
|
+
)
|
47
|
+
|
48
|
+
|
49
|
+
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
|
50
|
+
if is_torch_version(">=", "1.9.0"):
|
51
|
+
if (
|
52
|
+
is_peft_available()
|
53
|
+
and is_peft_version(">=", "0.13.1")
|
54
|
+
and is_transformers_available()
|
55
|
+
and is_transformers_version(">", "4.45.2")
|
56
|
+
):
|
57
|
+
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
|
58
|
+
|
59
|
+
|
60
|
+
logger = logging.get_logger(__name__)
|
61
|
+
|
62
|
+
TEXT_ENCODER_NAME = "text_encoder"
|
63
|
+
UNET_NAME = "unet"
|
64
|
+
TRANSFORMER_NAME = "transformer"
|
65
|
+
|
66
|
+
_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
|
67
|
+
|
68
|
+
|
69
|
+
class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
    """

    _lora_loadable_modules = ["unet", "text_encoder"]
    unet_name = UNET_NAME
    text_encoder_name = TEXT_ENCODER_NAME

    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
        `self.text_encoder`.

        All kwargs are forwarded to `self.lora_state_dict`.

        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
        loaded.

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
        loaded into `self.unet`.

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
        dict is loaded into `self.text_encoder`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            kwargs (`dict`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].

        Raises:
            ValueError: If the PEFT backend is not in use, if `low_cpu_mem_usage=True` is requested with a `peft`
                version older than 0.13.1, or if any state dict key does not contain `"lora"`.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        # Prefer `self.unet` / `self.text_encoder`; fall back to the attribute named by
        # `unet_name` / `text_encoder_name` when the conventional attribute is absent.
        self.load_lora_into_unet(
            state_dict,
            network_alphas=network_alphas,
            unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
            adapter_name=adapter_name,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )
        self.load_lora_into_text_encoder(
            state_dict,
            network_alphas=network_alphas,
            text_encoder=getattr(self, self.text_encoder_name)
            if not hasattr(self, "text_encoder")
            else self.text_encoder,
            lora_scale=self.lora_scale,
            adapter_name=adapter_name,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Return state dict for lora weights and the network alphas.

        <Tip warning={true}>

        We support loading A1111 formatted LoRA checkpoints in a limited capacity.

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.

            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            weight_name (`str`, *optional*, defaults to None):
                Name of the serialized state dict file.
        """
        # Load the main state dict first which has the LoRA layers for either of
        # UNet and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        unet_config = kwargs.pop("unet_config", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        # An explicit `use_safetensors` leaves `allow_pickle=False`; only when the caller did not
        # specify it do we request safetensors while also permitting pickle-based files.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )
        # DoRA `dora_scale` entries are not supported here, so they are dropped with a warning.
        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        # Keys prefixed with `lora_te_` / `lora_unet_` / `lora_te1_` / `lora_te2_` mark a non-diffusers
        # (e.g. A1111/kohya-style) checkpoint that must be converted; only this path yields `network_alphas`.
        network_alphas = None
        # TODO: replace it with a method from `state_dict_utils`
        if all(
            (
                k.startswith("lora_te_")
                or k.startswith("lora_unet_")
                or k.startswith("lora_te1_")
                or k.startswith("lora_te2_")
            )
            for k in state_dict.keys()
        ):
            # Map SDXL blocks correctly.
            if unet_config is not None:
                # use unet config to remap block numbers
                state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
            state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

        return state_dict, network_alphas

    @classmethod
    def load_lora_into_unet(
        cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `unet`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
            network_alphas (`Dict[str, float]`):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            unet (`UNet2DConditionModel`):
                The UNet model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
        # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
        # their prefixes.
        keys = list(state_dict.keys())
        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
        if not only_text_encoder:
            # Load the layers corresponding to UNet.
            logger.info(f"Loading {cls.unet_name}.")
            unet.load_lora_adapter(
                state_dict,
                prefix=cls.unet_name,
                network_alphas=network_alphas,
                adapter_name=adapter_name,
                _pipeline=_pipeline,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )

    @classmethod
    def load_lora_into_text_encoder(
        cls,
        state_dict,
        network_alphas,
        text_encoder,
        prefix=None,
        lora_scale=1.0,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The key should be prefixed with an
                additional `text_encoder` to distinguish between unet lora layers.
            network_alphas (`Dict[str, float]`):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            text_encoder (`CLIPTextModel`):
                The text encoder model to load the LoRA layers into.
            prefix (`str`):
                Expected prefix of the `text_encoder` in the `state_dict`.
            lora_scale (`float`):
                How much to scale the output of the lora linear layer before it is added with the output of the regular
                lora layer.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        _load_lora_into_text_encoder(
            state_dict=state_dict,
            network_alphas=network_alphas,
            lora_scale=lora_scale,
            text_encoder=text_encoder,
            prefix=prefix,
            text_encoder_name=cls.text_encoder_name,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and text encoder.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `unet`.
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
        """
        state_dict = {}

        if not (unet_lora_layers or text_encoder_lora_layers):
            raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")

        if unet_lora_layers:
            state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))

        if text_encoder_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    def fuse_lora(
        self,
        components: List[str] = ["unet", "text_encoder"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
            kwargs (`dict`, *optional*):
                Extra keyword arguments (for example legacy flags) forwarded unchanged to
                [`~loaders.lora_base.LoraBaseMixin.fuse_lora`].

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        # Forward `**kwargs` so the base implementation sees any extra/legacy arguments
        # instead of them being silently discarded here.
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
            kwargs (`dict`, *optional*):
                Extra keyword arguments forwarded unchanged to [`~loaders.lora_base.LoraBaseMixin.unfuse_lora`].
                Legacy flags such as `unfuse_unet` and `unfuse_text_encoder` are not interpreted by this method;
                choose what to unfuse with `components` instead.
        """
        # Forward `**kwargs` rather than dropping them, so the base class can handle legacy flags.
        super().unfuse_lora(components=components, **kwargs)
|
479
|
+
|
480
|
+
|
481
|
+
class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
482
|
+
r"""
|
483
|
+
Load LoRA layers into Stable Diffusion XL [`UNet2DConditionModel`],
|
484
|
+
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
|
485
|
+
[`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
|
486
|
+
"""
|
487
|
+
|
488
|
+
_lora_loadable_modules = ["unet", "text_encoder", "text_encoder_2"]
|
489
|
+
unet_name = UNET_NAME
|
490
|
+
text_encoder_name = TEXT_ENCODER_NAME
|
491
|
+
|
492
|
+
def load_lora_weights(
    self,
    pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    adapter_name: Optional[str] = None,
    **kwargs,
):
    """
    Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet`, `self.text_encoder`
    and `self.text_encoder_2`.

    All kwargs are forwarded to `self.lora_state_dict`.

    See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
    loaded.

    See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
    loaded into `self.unet`.

    See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
    dict is loaded into `self.text_encoder` and `self.text_encoder_2`.

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
        adapter_name (`str`, *optional*):
            Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
            `default_{i}` where i is the total number of adapters being loaded.
        low_cpu_mem_usage (`bool`, *optional*):
            Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
            weights.
        kwargs (`dict`, *optional*):
            See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. Do not pass `unet_config`: it is
            always supplied from `self.unet.config`, and passing it again raises a `TypeError`.

    Raises:
        ValueError: If the PEFT backend is not in use, if `low_cpu_mem_usage=True` is requested with a `peft`
            version older than 0.13.1, or if any state dict key does not contain `"lora"`.
    """
    if not USE_PEFT_BACKEND:
        raise ValueError("PEFT backend is required for this method.")

    low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
    if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
        raise ValueError(
            "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
        )

    # if a dict is passed, copy it instead of modifying it inplace
    if isinstance(pretrained_model_name_or_path_or_dict, dict):
        pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

    # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
    # We could have accessed the unet config from `lora_state_dict()` too. We pass
    # it here explicitly to be able to tell that it's coming from an SDXL
    # pipeline.
    state_dict, network_alphas = self.lora_state_dict(
        pretrained_model_name_or_path_or_dict,
        unet_config=self.unet.config,
        **kwargs,
    )

    is_correct_format = all("lora" in key for key in state_dict.keys())
    if not is_correct_format:
        raise ValueError("Invalid LoRA checkpoint.")

    self.load_lora_into_unet(
        state_dict,
        network_alphas=network_alphas,
        unet=self.unet,
        adapter_name=adapter_name,
        _pipeline=self,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )

    # The substring filters are disjoint: "text_encoder." never matches a "text_encoder_2." key.
    text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
    if len(text_encoder_state_dict) > 0:
        self.load_lora_into_text_encoder(
            text_encoder_state_dict,
            network_alphas=network_alphas,
            text_encoder=self.text_encoder,
            prefix="text_encoder",
            lora_scale=self.lora_scale,
            adapter_name=adapter_name,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
    if len(text_encoder_2_state_dict) > 0:
        self.load_lora_into_text_encoder(
            text_encoder_2_state_dict,
            network_alphas=network_alphas,
            text_encoder=self.text_encoder_2,
            prefix="text_encoder_2",
            lora_scale=self.lora_scale,
            adapter_name=adapter_name,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )
|
586
|
+
|
587
|
+
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict
def lora_state_dict(
    cls,
    pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    **kwargs,
):
    r"""
    Return state dict for lora weights and the network alphas.

    <Tip warning={true}>

    We support loading A1111 formatted LoRA checkpoints in a limited capacity.

    This function is experimental and might change in the future.

    </Tip>

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  with [`ModelMixin.save_pretrained`].
                - A [torch state
                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
            is not used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.

        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether to only load local model weights and configuration files or not. If set to `True`, the model
            won't be downloaded from the Hub.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
            `diffusers-cli login` (stored in `~/.huggingface`) is used.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
            allowed by Git.
        subfolder (`str`, *optional*, defaults to `""`):
            The subfolder location of a model file within a larger model repository on the Hub or locally.
        weight_name (`str`, *optional*, defaults to None):
            Name of the serialized state dict file.
    """
    # Load the main state dict first which has the LoRA layers for either of
    # UNet and text encoder or both.
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", None)
    token = kwargs.pop("token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)
    weight_name = kwargs.pop("weight_name", None)
    unet_config = kwargs.pop("unet_config", None)
    use_safetensors = kwargs.pop("use_safetensors", None)

    # An explicit `use_safetensors` leaves `allow_pickle=False`; only when the caller did not
    # specify it do we request safetensors while also permitting pickle-based files.
    allow_pickle = False
    if use_safetensors is None:
        use_safetensors = True
        allow_pickle = True

    user_agent = {
        "file_type": "attn_procs_weights",
        "framework": "pytorch",
    }

    state_dict = _fetch_state_dict(
        pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
        weight_name=weight_name,
        use_safetensors=use_safetensors,
        local_files_only=local_files_only,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        token=token,
        revision=revision,
        subfolder=subfolder,
        user_agent=user_agent,
        allow_pickle=allow_pickle,
    )
    # DoRA `dora_scale` entries are not supported here, so they are dropped with a warning.
    is_dora_scale_present = any("dora_scale" in k for k in state_dict)
    if is_dora_scale_present:
        warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
        logger.warning(warn_msg)
        state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

    # Keys prefixed with `lora_te_` / `lora_unet_` / `lora_te1_` / `lora_te2_` mark a non-diffusers
    # (e.g. A1111/kohya-style) checkpoint that must be converted; only this path yields `network_alphas`.
    network_alphas = None
    # TODO: replace it with a method from `state_dict_utils`
    if all(
        (
            k.startswith("lora_te_")
            or k.startswith("lora_unet_")
            or k.startswith("lora_te1_")
            or k.startswith("lora_te2_")
        )
        for k in state_dict.keys()
    ):
        # Map SDXL blocks correctly.
        if unet_config is not None:
            # use unet config to remap block numbers
            state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
        state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

    return state_dict, network_alphas
|
702
|
+
|
703
|
+
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
def load_lora_into_unet(
    cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
):
    """
    This will load the LoRA layers specified in `state_dict` into `unet`.

    Parameters:
        state_dict (`dict`):
            A standard state dict containing the lora layer parameters. The keys can either be indexed directly
            into the unet or prefixed with an additional `unet` which can be used to distinguish between text
            encoder lora layers.
        network_alphas (`Dict[str, float]`):
            The value of the network alpha used for stable learning and preventing underflow. This value has the
            same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
            link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
        unet (`UNet2DConditionModel`):
            The UNet model to load the LoRA layers into.
        adapter_name (`str`, *optional*):
            Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
            `default_{i}` where i is the total number of adapters being loaded.
        low_cpu_mem_usage (`bool`, *optional*):
            Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
            weights.
    """
    if not USE_PEFT_BACKEND:
        raise ValueError("PEFT backend is required for this method.")

    if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
        raise ValueError(
            "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
        )

    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
    # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
    # their prefixes.
    keys = list(state_dict.keys())
    only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
    if not only_text_encoder:
        # Load the layers corresponding to UNet.
        logger.info(f"Loading {cls.unet_name}.")
        unet.load_lora_adapter(
            state_dict,
            prefix=cls.unet_name,
            network_alphas=network_alphas,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )
|
753
|
+
|
754
|
+
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
def load_lora_into_text_encoder(
    cls,
    state_dict,
    network_alphas,
    text_encoder,
    prefix=None,
    lora_scale=1.0,
    adapter_name=None,
    _pipeline=None,
    low_cpu_mem_usage=False,
):
    """
    Inject the LoRA layers contained in `state_dict` into `text_encoder`.

    All arguments are forwarded unchanged to the module-level `_load_lora_into_text_encoder` helper, together
    with this class's `text_encoder_name`. Nothing is returned.

    Parameters:
        state_dict (`dict`):
            State dict holding the LoRA layer parameters. Keys are expected to carry an extra `text_encoder`
            prefix so they can be told apart from the unet LoRA layers.
        network_alphas (`Dict[str, float]`):
            Network alpha values used for stable learning and preventing underflow; same meaning as the
            `--network_alpha` option of the kohya-ss trainer script. See [this
            link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
        text_encoder (`CLIPTextModel`):
            Text encoder that receives the LoRA layers.
        prefix (`str`):
            Expected prefix of the `text_encoder` keys in `state_dict`.
        lora_scale (`float`):
            Scale applied to the output of the LoRA linear layer before it is added to the output of the
            regular layer.
        adapter_name (`str`, *optional*):
            Name under which the loaded adapter is registered. When omitted, `default_{i}` is used, where `i`
            is the total number of adapters being loaded.
        low_cpu_mem_usage (`bool`, *optional*):
            Speed up loading by only loading the pretrained LoRA weights instead of initializing random ones.
    """
    forwarded_kwargs = dict(
        state_dict=state_dict,
        network_alphas=network_alphas,
        lora_scale=lora_scale,
        text_encoder=text_encoder,
        prefix=prefix,
        text_encoder_name=cls.text_encoder_name,
        adapter_name=adapter_name,
        _pipeline=_pipeline,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )
    _load_lora_into_text_encoder(**forwarded_kwargs)
|
803
|
+
|
804
|
+
@classmethod
def save_lora_weights(
    cls,
    save_directory: Union[str, os.PathLike],
    unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
    text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
    text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
    is_main_process: bool = True,
    weight_name: str = None,
    save_function: Callable = None,
    safe_serialization: bool = True,
):
    r"""
    Serialize the LoRA parameters of the UNet and of both text encoders into `save_directory`.

    Each non-empty group of layers is packed under its component prefix (`unet`, `text_encoder`,
    `text_encoder_2`). The merged state dict is then written with `cls.write_lora_layers`.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Target directory for the LoRA parameters. Will be created if it doesn't exist.
        unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
            State dict of the LoRA layers belonging to the `unet`.
        text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
            State dict of the LoRA layers belonging to `text_encoder`. It must be passed explicitly because the
            text encoder comes from 🤗 Transformers.
        text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
            State dict of the LoRA layers belonging to `text_encoder_2`. It must be passed explicitly because the
            text encoder comes from 🤗 Transformers.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the calling process is the main one. During distributed training, call this on every
            process and set `is_main_process=True` only on the main one to avoid race conditions.
        weight_name (`str`, *optional*):
            File name for the serialized weights; forwarded to `write_lora_layers`.
        save_function (`Callable`):
            Function used to persist the state dict, e.g. to replace `torch.save` during distributed training.
            Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`.
        safe_serialization (`bool`, *optional*, defaults to `True`):
            Save with `safetensors` when `True`, otherwise with PyTorch's `pickle`-based format.

    Raises:
        ValueError: if none of the three LoRA layer groups is provided (or all are empty).
    """
    # (prefix, layers) pairs, in the order they are merged into the output state dict.
    named_layer_groups = (
        ("unet", unet_lora_layers),
        ("text_encoder", text_encoder_lora_layers),
        ("text_encoder_2", text_encoder_2_lora_layers),
    )

    if not any(layers for _, layers in named_layer_groups):
        raise ValueError(
            "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
        )

    packed_state_dict = {}
    for component_prefix, layers in named_layer_groups:
        if layers:
            packed_state_dict.update(cls.pack_weights(layers, component_prefix))

    cls.write_lora_layers(
        state_dict=packed_state_dict,
        save_directory=save_directory,
        is_main_process=is_main_process,
        weight_name=weight_name,
        save_function=save_function,
        safe_serialization=safe_serialization,
    )
|
865
|
+
|
866
|
+
def fuse_lora(
    self,
    components: List[str] = ["unet", "text_encoder", "text_encoder_2"],
    lora_scale: float = 1.0,
    safe_fusing: bool = False,
    adapter_names: Optional[List[str]] = None,
    **kwargs,
):
    r"""
    Merge the LoRA parameters into the original weights of the selected components.

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Only `components`, `lora_scale`, `safe_fusing` and `adapter_names` are forwarded to the parent
    `fuse_lora`; any extra keyword arguments are accepted but not used.

    Args:
        components: (`List[str]`): LoRA-injectable components whose weights should absorb the LoRAs.
        lora_scale (`float`, defaults to 1.0):
            Strength with which the LoRA parameters influence the outputs.
        safe_fusing (`bool`, defaults to `False`):
            Whether to check the fused weights for NaN values first and skip fusing when NaNs are found.
        adapter_names (`List[str]`, *optional*):
            Adapters to fuse. When omitted, every active adapter is fused.

    Example:

    ```py
    from diffusers import DiffusionPipeline
    import torch

    pipeline = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
    pipeline.fuse_lora(lora_scale=0.7)
    ```
    """
    fuse_options = {
        "components": components,
        "lora_scale": lora_scale,
        "safe_fusing": safe_fusing,
        "adapter_names": adapter_names,
    }
    super().fuse_lora(**fuse_options)
|
908
|
+
|
909
|
+
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
    r"""
    Reverses the effect of
    [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Args:
        components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
        kwargs:
            Accepted for call compatibility but ignored: only `components` is forwarded to the parent
            `unfuse_lora`. In particular, keywords such as `unfuse_unet` or `unfuse_text_encoder` have no
            effect; select what to unfuse through `components` instead.
    """
    # Only `components` is forwarded; extra keyword arguments are intentionally dropped.
    super().unfuse_lora(components=components)
|
928
|
+
|
929
|
+
|
930
|
+
class SD3LoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`SD3Transformer2DModel`],
    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
    [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).

    Specific to [`StableDiffusion3Pipeline`].
    """

    _lora_loadable_modules = ["transformer", "text_encoder", "text_encoder_2"]
    transformer_name = TRANSFORMER_NAME
    text_encoder_name = TEXT_ENCODER_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Return the LoRA state dict. Unlike some other loaders, this method does not return network alphas.

        Keys containing `dora_scale` are removed from the returned state dict, and a warning is logged.

        <Tip warning={true}>

        We support loading A1111 formatted LoRA checkpoints in a limited capacity.

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.

            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            weight_name (`str`, *optional*):
                Name of the weights file to load.
            use_safetensors (`bool`, *optional*):
                If left as `None`, safetensors is requested and loading a pickled checkpoint is also permitted
                (`allow_pickle=True`). If set explicitly, the given value is used and `allow_pickle` stays `False`.

        Returns:
            `dict`: The LoRA state dict, with any `dora_scale` entries removed.
        """
        # Load the main state dict first which has the LoRA layers for either of
        # transformer and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        # Pickle is only allowed when the caller did not explicitly choose a format.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        return state_dict

    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`,
        `self.text_encoder` and `self.text_encoder_2`. A component is loaded only if at least one key in the state
        dict contains its prefix (`transformer.`, `text_encoder.` or `text_encoder_2.`).

        All remaining kwargs are forwarded to `self.lora_state_dict`.

        See [`~loaders.SD3LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.

        See [`~loaders.SD3LoraLoaderMixin.load_lora_into_transformer`] for more details on how the state dict is
        loaded into `self.transformer`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.SD3LoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            kwargs (`dict`, *optional*):
                See [`~loaders.SD3LoraLoaderMixin.lora_state_dict`].

        Raises:
            ValueError: if the PEFT backend is unavailable, if `low_cpu_mem_usage=True` is requested with an
                incompatible `peft` version, or if any state dict key does not contain `"lora"`.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        # The transformer receives the full state dict; `transformer_state_dict` is only used to decide
        # whether there is anything to load for it.
        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
        if len(transformer_state_dict) > 0:
            self.load_lora_into_transformer(
                state_dict,
                transformer=getattr(self, self.transformer_name)
                if not hasattr(self, "transformer")
                else self.transformer,
                adapter_name=adapter_name,
                _pipeline=self,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )

        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alphas=None,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
                adapter_name=adapter_name,
                _pipeline=self,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )

        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
        if len(text_encoder_2_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
                network_alphas=None,
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
                adapter_name=adapter_name,
                _pipeline=self,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )

    @classmethod
    def load_lora_into_transformer(
        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `transformer`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the transformer or prefixed with an additional `transformer` that distinguishes them from the
                text encoder lora layers. No `prefix` argument is passed to `transformer.load_lora_adapter` here,
                so its default applies.
            transformer (`SD3Transformer2DModel`):
                The Transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.

        Raises:
            ValueError: if `low_cpu_mem_usage=True` is requested with an incompatible `peft` version.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
    def load_lora_into_text_encoder(
        cls,
        state_dict,
        network_alphas,
        text_encoder,
        prefix=None,
        lora_scale=1.0,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The key should be prefixed with an
                additional `text_encoder` to distinguish between unet lora layers.
            network_alphas (`Dict[str, float]`):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            text_encoder (`CLIPTextModel`):
                The text encoder model to load the LoRA layers into.
            prefix (`str`):
                Expected prefix of the `text_encoder` in the `state_dict`.
            lora_scale (`float`):
                How much to scale the output of the lora linear layer before it is added with the output of the regular
                lora layer.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        _load_lora_into_text_encoder(
            state_dict=state_dict,
            network_alphas=network_alphas,
            lora_scale=lora_scale,
            text_encoder=text_encoder,
            prefix=prefix,
            text_encoder_name=cls.text_encoder_name,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the transformer and the two text encoders.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `transformer`. Packed under `cls.transformer_name`.
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            weight_name (`str`, *optional*):
                File name for the serialized weights; forwarded to `write_lora_layers`.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.

        Raises:
            ValueError: if none of the three LoRA layer groups is provided (or all are empty).
        """
        state_dict = {}

        if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
            raise ValueError(
                "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
            )

        if transformer_lora_layers:
            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

        if text_encoder_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))

        if text_encoder_2_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    def fuse_lora(
        self,
        components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
            kwargs:
                Accepted but ignored; they are not forwarded to the parent `fuse_lora`.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        super().fuse_lora(
            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
        )

    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
            kwargs:
                Accepted for call compatibility but ignored: only `components` is forwarded to the parent
                `unfuse_lora`. Keywords such as `unfuse_unet` or `unfuse_text_encoder` have no effect; select what
                to unfuse through `components` instead.
        """
        super().unfuse_lora(components=components)
|
1336
|
+
|
1337
|
+
|
1338
|
+
class FluxLoraLoaderMixin(LoraBaseMixin):
|
1339
|
+
r"""
|
1340
|
+
Load LoRA layers into [`FluxTransformer2DModel`],
|
1341
|
+
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
|
1342
|
+
|
1343
|
+
Specific to [`StableDiffusion3Pipeline`].
|
1344
|
+
"""
|
1345
|
+
|
1346
|
+
_lora_loadable_modules = ["transformer", "text_encoder"]
|
1347
|
+
transformer_name = TRANSFORMER_NAME
|
1348
|
+
text_encoder_name = TEXT_ENCODER_NAME
|
1349
|
+
_control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
|
1350
|
+
|
1351
|
+
@classmethod
@validate_hf_hub_args
def lora_state_dict(
    cls,
    pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    return_alphas: bool = False,
    **kwargs,
):
    r"""
    Return state dict for lora weights and the network alphas.

    <Tip warning={true}>

    We support loading A1111 formatted LoRA checkpoints in a limited capacity.

    This function is experimental and might change in the future.

    </Tip>

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  with [`ModelMixin.save_pretrained`].
                - A [torch state
                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
                  The passed dict is not modified; a shallow copy of it is processed instead.

        return_alphas (`bool`, *optional*, defaults to `False`):
            If `True`, return a `(state_dict, network_alphas)` tuple instead of only the state dict.
            `network_alphas` is `None` for Kohya, XLabs and BFL Control checkpoints, and otherwise a dict of
            the `alpha` entries that were popped out of the state dict.
        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
            is not used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.

        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether to only load local model weights and configuration files or not. If set to `True`, the model
            won't be downloaded from the Hub.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
            `diffusers-cli login` (stored in `~/.huggingface`) is used.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
            allowed by Git.
        subfolder (`str`, *optional*, defaults to `""`):
            The subfolder location of a model file within a larger model repository on the Hub or locally.
        weight_name (`str`, *optional*):
            Name of the weight file to load; forwarded to `_fetch_state_dict`.
        use_safetensors (`bool`, *optional*):
            If unset, `use_safetensors=True` and `allow_pickle=True` are passed to `_fetch_state_dict`. If set
            explicitly, that value is used and `allow_pickle=False`.

    Returns:
        The (possibly converted) LoRA state dict, or a `(state_dict, network_alphas)` tuple when
        `return_alphas=True`.

    Raises:
        `ValueError`: If a key containing `alpha` holds a value that is neither a floating point tensor nor a
        `float`.
    """
    # Load the main state dict first which has the LoRA layers for either of
    # transformer and text encoder or both.
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", None)
    token = kwargs.pop("token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)
    weight_name = kwargs.pop("weight_name", None)
    use_safetensors = kwargs.pop("use_safetensors", None)

    allow_pickle = False
    if use_safetensors is None:
        use_safetensors = True
        allow_pickle = True

    # The alpha extraction below pops keys from `state_dict`, and the format converters may also modify their
    # input. `_fetch_state_dict` may hand back the very dict it was given, so work on a shallow copy to keep
    # callers that use this classmethod directly from seeing their dict mutated.
    if isinstance(pretrained_model_name_or_path_or_dict, dict):
        pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

    user_agent = {
        "file_type": "attn_procs_weights",
        "framework": "pytorch",
    }

    state_dict = _fetch_state_dict(
        pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
        weight_name=weight_name,
        use_safetensors=use_safetensors,
        local_files_only=local_files_only,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        token=token,
        revision=revision,
        subfolder=subfolder,
        user_agent=user_agent,
        allow_pickle=allow_pickle,
    )
    is_dora_scale_present = any("dora_scale" in k for k in state_dict)
    if is_dora_scale_present:
        warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
        logger.warning(warn_msg)
        state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

    # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
    is_kohya = any(".lora_down.weight" in k for k in state_dict)
    if is_kohya:
        state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
        # Kohya already takes care of scaling the LoRA parameters with alpha.
        return (state_dict, None) if return_alphas else state_dict

    is_xlabs = any("processor" in k for k in state_dict)
    if is_xlabs:
        state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
        # xlabs doesn't use `alpha`.
        return (state_dict, None) if return_alphas else state_dict

    is_bfl_control = any("query_norm.scale" in k for k in state_dict)
    if is_bfl_control:
        state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
        return (state_dict, None) if return_alphas else state_dict

    # For state dicts like
    # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
    keys = list(state_dict.keys())
    network_alphas = {}
    for k in keys:
        if "alpha" in k:
            alpha_value = state_dict.get(k)
            if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
                alpha_value, float
            ):
                network_alphas[k] = state_dict.pop(k)
            else:
                raise ValueError(
                    f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
                )

    if return_alphas:
        return state_dict, network_alphas
    else:
        return state_dict
|
1484
|
+
|
1485
|
+
def load_lora_weights(
    self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
    """
    Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
    `self.text_encoder`.

    All kwargs are forwarded to `self.lora_state_dict`.

    See [`~loaders.FluxLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.

    See [`~loaders.FluxLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state dict is
    loaded into `self.transformer`.

    Besides LoRA layers, the checkpoint may contain normalization parameters (Flux Control LoRAs). Those are
    written directly into the transformer, and the values they replace are kept in
    `transformer._transformer_norm_layers` so that they can be restored later.

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            See [`~loaders.FluxLoraLoaderMixin.lora_state_dict`]. A passed dict is copied, not modified.
        kwargs (`dict`, *optional*):
            See [`~loaders.FluxLoraLoaderMixin.lora_state_dict`].
        adapter_name (`str`, *optional*):
            Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
            `default_{i}` where i is the total number of adapters being loaded.
        low_cpu_mem_usage (`bool`, *optional*):
            Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
            weights.

    Raises:
        `ValueError`: If the PEFT backend is unavailable, if `low_cpu_mem_usage=True` is requested with a `peft`
        version older than 0.13.1, or if the checkpoint contains neither LoRA nor supported norm keys.
    """
    if not USE_PEFT_BACKEND:
        raise ValueError("PEFT backend is required for this method.")

    low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
    if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
        raise ValueError(
            "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
        )

    # if a dict is passed, copy it instead of modifying it inplace
    if isinstance(pretrained_model_name_or_path_or_dict, dict):
        pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

    # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
    state_dict, network_alphas = self.lora_state_dict(
        pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
    )

    has_lora_keys = any("lora" in key for key in state_dict.keys())

    # Flux Control LoRAs also have norm keys
    has_norm_keys = any(
        norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys
    )

    if not (has_lora_keys or has_norm_keys):
        raise ValueError("Invalid LoRA checkpoint.")

    # Split the transformer part into LoRA weights and plain norm parameters; both are popped so that only
    # the remaining (text encoder) keys stay in `state_dict`.
    transformer_lora_state_dict = {
        k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
    }
    transformer_norm_state_dict = {
        k: state_dict.pop(k)
        for k in list(state_dict.keys())
        if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
    }

    transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
    has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
        transformer, transformer_lora_state_dict, transformer_norm_state_dict
    )

    if has_param_with_expanded_shape:
        logger.info(
            "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
            "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
            "To get a comprehensive list of parameter names that were modified, enable debug logging."
        )
    transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
        transformer=transformer, lora_state_dict=transformer_lora_state_dict
    )

    if len(transformer_lora_state_dict) > 0:
        self.load_lora_into_transformer(
            transformer_lora_state_dict,
            network_alphas=network_alphas,
            transformer=transformer,
            adapter_name=adapter_name,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    if len(transformer_norm_state_dict) > 0:
        # Keep the overwritten originals so that `unfuse_lora` / `unload_lora_weights` can restore them.
        transformer._transformer_norm_layers = self._load_norm_into_transformer(
            transformer_norm_state_dict,
            transformer=transformer,
            discard_original_layers=False,
        )

    text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
    if len(text_encoder_state_dict) > 0:
        self.load_lora_into_text_encoder(
            text_encoder_state_dict,
            network_alphas=network_alphas,
            text_encoder=self.text_encoder,
            prefix="text_encoder",
            lora_scale=self.lora_scale,
            adapter_name=adapter_name,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )
|
1593
|
+
|
1594
|
+
@classmethod
def load_lora_into_transformer(
    cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
):
    """
    This will load the LoRA layers specified in `state_dict` into `transformer`.

    Parameters:
        state_dict (`dict`):
            A standard state dict containing the lora layer parameters. Loading only happens when at least one
            key starts with `cls.transformer_name`; in that case the whole dict is handed to
            `transformer.load_lora_adapter`. Otherwise this method is a no-op.
        network_alphas (`Dict[str, float]` or `None`):
            The value of the network alpha used for stable learning and preventing underflow. This value has the
            same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
            link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            `None` when the checkpoint format provides no alphas (see `lora_state_dict`).
        transformer (`FluxTransformer2DModel`):
            The Transformer model to load the LoRA layers into.
        adapter_name (`str`, *optional*):
            Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
            `default_{i}` where i is the total number of adapters being loaded.
        low_cpu_mem_usage (`bool`, *optional*):
            Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
            weights.

    Raises:
        `ValueError`: If `low_cpu_mem_usage=True` and the installed `peft` is older than 0.13.1.
    """
    if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
        raise ValueError(
            "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
        )

    # Load the layers corresponding to transformer.
    keys = list(state_dict.keys())
    transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
    if transformer_present:
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=network_alphas,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )
|
1636
|
+
|
1637
|
+
@classmethod
def _load_norm_into_transformer(
    cls,
    state_dict,
    transformer,
    prefix=None,
    discard_original_layers=False,
) -> Dict[str, torch.Tensor]:
    """
    Write normalization-layer parameters (e.g. from a Flux Control LoRA) directly into `transformer`.

    Unlike LoRA layers, these parameters overwrite the transformer's own weights via `load_state_dict`.

    Args:
        state_dict (`Dict[str, torch.Tensor]`):
            Normalization parameters. Keys whose first dotted component equals `prefix` have that prefix
            stripped. Keys that do not exist in `transformer.state_dict()` are ignored with a warning. The passed
            dict itself is not modified.
        transformer (`torch.nn.Module`):
            The model to load the parameters into.
        prefix (`str`, *optional*):
            Key prefix to strip. Defaults to `cls.transformer_name`.
        discard_original_layers (`bool`, defaults to `False`):
            If `False`, a clone of every parameter that is about to be overwritten is returned so that it can be
            restored later. If `True`, nothing is saved.

    Returns:
        `Dict[str, torch.Tensor]`: The original values of the overwritten parameters (empty when
        `discard_original_layers=True`).

    Raises:
        `ValueError`: If `transformer.load_state_dict` reports a supported norm key as unexpected.
    """
    # Work on a copy: the prefix renaming and the removal of unsupported keys below must not leak into the
    # caller's dict.
    state_dict = dict(state_dict)

    # Remove prefix if present
    prefix = prefix or cls.transformer_name
    for key in list(state_dict.keys()):
        if key.split(".")[0] == prefix:
            state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)

    # Find invalid keys
    transformer_state_dict = transformer.state_dict()
    transformer_keys = set(transformer_state_dict.keys())
    state_dict_keys = set(state_dict.keys())
    extra_keys = list(state_dict_keys - transformer_keys)

    if extra_keys:
        logger.warning(
            f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
        )

    for key in extra_keys:
        state_dict.pop(key)

    # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected
    overwritten_layers_state_dict = {}
    if not discard_original_layers:
        for key in state_dict.keys():
            overwritten_layers_state_dict[key] = transformer_state_dict[key].clone()

    # Originals are only kept when `discard_original_layers=False`, so that is the case in which the norm layers
    # can be unfused again.
    logger.info(
        "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
        'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly '
        "fused into the transformer and can only be unfused if `discard_original_layers=False` is passed. This might also have implications when dealing with multiple LoRAs. "
        "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues."
    )

    # We can't load with strict=True because the current state_dict does not contain all the transformer keys
    incompatible_keys = transformer.load_state_dict(state_dict, strict=False)
    unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)

    # We shouldn't expect to see the supported norm keys here being present in the unexpected keys.
    if unexpected_keys:
        if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys):
            raise ValueError(
                f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer."
            )

    return overwritten_layers_state_dict
|
1690
|
+
|
1691
|
+
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
def load_lora_into_text_encoder(
    cls,
    state_dict,
    network_alphas,
    text_encoder,
    prefix=None,
    lora_scale=1.0,
    adapter_name=None,
    _pipeline=None,
    low_cpu_mem_usage=False,
):
    """
    Inject the LoRA layers found in `state_dict` into `text_encoder`.

    Parameters:
        state_dict (`dict`):
            Standard state dict with the LoRA layer parameters. Keys are expected to carry an extra
            `text_encoder` prefix that separates them from the denoiser's LoRA layers.
        network_alphas (`Dict[str, float]`):
            Network alpha values used for stable learning and preventing underflow, with the same meaning as the
            `--network_alpha` option of the kohya-ss trainer script. See [this
            link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
        text_encoder (`CLIPTextModel`):
            Text encoder that receives the LoRA layers.
        prefix (`str`):
            Prefix under which the `text_encoder` entries are expected in `state_dict`.
        lora_scale (`float`):
            Scale applied to the output of the LoRA linear layer before it is added to the output of the
            regular layer.
        adapter_name (`str`, *optional*):
            Name used to reference the loaded adapter. When omitted, `default_{i}` is used, where i is the total
            number of adapters being loaded.
        low_cpu_mem_usage (`bool`, *optional*):
            Speed up loading by only loading the pretrained LoRA weights instead of initializing random weights.
    """
    # The shared module-level helper does all the work; this classmethod only contributes the class-specific
    # `text_encoder_name`.
    loader_kwargs = dict(
        state_dict=state_dict,
        network_alphas=network_alphas,
        text_encoder=text_encoder,
        prefix=prefix,
        lora_scale=lora_scale,
        adapter_name=adapter_name,
        text_encoder_name=cls.text_encoder_name,
        _pipeline=_pipeline,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )
    _load_lora_into_text_encoder(**loader_kwargs)
|
1740
|
+
|
1741
|
+
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
def save_lora_weights(
    cls,
    save_directory: Union[str, os.PathLike],
    transformer_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
    text_encoder_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
    is_main_process: bool = True,
    weight_name: Optional[str] = None,
    save_function: Optional[Callable] = None,
    safe_serialization: bool = True,
):
    r"""
    Save the LoRA parameters corresponding to the transformer and text encoder.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to save LoRA parameters to. Will be created if it doesn't exist.
        transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`, *optional*):
            State dict of the LoRA layers corresponding to the `transformer`. Saved under the
            `cls.transformer_name` prefix.
        text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`, *optional*):
            State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
            encoder LoRA state dict because it comes from 🤗 Transformers. Saved under the
            `cls.text_encoder_name` prefix.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the process calling this is the main process or not. Useful during distributed training and you
            need to call this function on all processes. In this case, set `is_main_process=True` only on the main
            process to avoid race conditions.
        weight_name (`str`, *optional*):
            Name of the file to write; forwarded to `write_lora_layers`.
        save_function (`Callable`, *optional*):
            The function to use to save the state dictionary. Useful during distributed training when you need to
            replace `torch.save` with another method. Can be configured with the environment variable
            `DIFFUSERS_SAVE_MODE`.
        safe_serialization (`bool`, *optional*, defaults to `True`):
            Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.

    Raises:
        `ValueError`: If neither `transformer_lora_layers` nor `text_encoder_lora_layers` is provided (or both
        are empty).
    """
    state_dict = {}

    if not (transformer_lora_layers or text_encoder_lora_layers):
        raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")

    if transformer_lora_layers:
        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

    if text_encoder_lora_layers:
        state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

    # Save the model
    cls.write_lora_layers(
        state_dict=state_dict,
        save_directory=save_directory,
        is_main_process=is_main_process,
        weight_name=weight_name,
        save_function=save_function,
        safe_serialization=safe_serialization,
    )
|
1795
|
+
|
1796
|
+
def fuse_lora(
    self,
    components: List[str] = ["transformer"],
    lora_scale: float = 1.0,
    safe_fusing: bool = False,
    adapter_names: Optional[List[str]] = None,
    **kwargs,
):
    r"""
    Fuses the LoRA parameters into the original parameters of the corresponding blocks.

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Normalization layers that were loaded together with the LoRA (e.g. by a Flux Control LoRA) are not touched
    here: `load_lora_weights` already wrote them into the transformer.

    Args:
        components (`List[str]`, defaults to `["transformer"]`): List of LoRA-injectable components to fuse the
            LoRAs into.
        lora_scale (`float`, defaults to 1.0):
            Controls how much to influence the outputs with the LoRA parameters.
        safe_fusing (`bool`, defaults to `False`):
            Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
        adapter_names (`List[str]`, *optional*):
            Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
        kwargs (`dict`, *optional*):
            Additional keyword arguments, forwarded unchanged to the base class `fuse_lora`.

    Example:

    ```py
    from diffusers import DiffusionPipeline
    import torch

    pipeline = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
    pipeline.fuse_lora(lora_scale=0.7)
    ```
    """
    transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
    norm_layers = getattr(transformer, "_transformer_norm_layers", None)
    if isinstance(norm_layers, dict) and len(norm_layers) > 0:
        # The originals are only kept when the norm layers were loaded with `discard_original_layers=False`,
        # which is what makes unfusing them possible.
        logger.info(
            "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers have already been directly loaded into the state_dict of the transformer "
            "as opposed to the LoRA layers that co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers are always directly "
            "fused into the transformer and can only be unfused if `discard_original_layers=False` was used when loading them."
        )

    # Forward extra keyword arguments instead of silently dropping them.
    # NOTE(review): legacy flags such as `fuse_unet` are expected to be handled/deprecated by the base class.
    super().fuse_lora(
        components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names, **kwargs
    )
|
1851
|
+
|
1852
|
+
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
    r"""
    Reverses the effect of
    [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    If normalization layers were overwritten when the LoRA was loaded, their saved original values
    (`transformer._transformer_norm_layers`) are loaded back into the transformer first.

    Args:
        components (`List[str]`, defaults to `["transformer", "text_encoder"]`):
            List of LoRA-injectable components to unfuse LoRA from.
        kwargs (`dict`, *optional*):
            Additional keyword arguments, forwarded unchanged to the base class `unfuse_lora`.
    """
    transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
    if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
        # strict=False: the saved dict only holds the norm parameters, not the full transformer state.
        transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)

    # Forward extra keyword arguments instead of silently dropping them.
    # NOTE(review): legacy flags such as `unfuse_transformer` are expected to be handled/deprecated by the base class.
    super().unfuse_lora(components=components, **kwargs)
|
1871
|
+
|
1872
|
+
# We override this here to account for `_transformer_norm_layers` and `_overwritten_params`.
def unload_lora_weights(self, reset_to_overwritten_params=False):
    """
    Unloads the LoRA parameters.

    In addition to the base implementation, this restores normalization layers that were overwritten when the
    LoRA was loaded. Optionally, it also replaces `torch.nn.Linear` modules whose parameters were recorded in
    `transformer._overwritten_params` with modules rebuilt from those original parameters.

    Args:
        reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules
            to their original params. Refer to the [Flux
            documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.

    Examples:

    ```python
    >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
    >>> pipeline.unload_lora_weights()
    >>> ...
    ```
    """
    super().unload_lora_weights()

    transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
    if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
        transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
        transformer._transformer_norm_layers = None

    if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
        overwritten_params = transformer._overwritten_params
        weight_suffix = ".weight"
        # Strip only the trailing ".weight". `str.replace` would also remove ".weight" occurring elsewhere in a
        # parameter name (e.g. "blk.weight_proj.weight"), producing a wrong module name.
        module_names = {
            param_name[: -len(weight_suffix)]
            for param_name in overwritten_params
            if param_name.endswith(weight_suffix)
        }

        # Snapshot the module list: matching modules are swapped out on their parents inside the loop.
        for name, module in list(transformer.named_modules()):
            if not (isinstance(module, torch.nn.Linear) and name in module_names):
                continue

            module_weight = module.weight.data
            bias = module.bias is not None

            parent_module_name, _, current_module_name = name.rpartition(".")
            parent_module = transformer.get_submodule(parent_module_name)

            current_param_weight = overwritten_params[f"{name}.weight"]
            in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
            # Build on the meta device so that no memory is allocated before the original tensors are assigned.
            with torch.device("meta"):
                original_module = torch.nn.Linear(
                    in_features,
                    out_features,
                    bias=bias,
                    dtype=module_weight.dtype,
                )

            tmp_state_dict = {"weight": current_param_weight}
            if bias:
                tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
            original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
            setattr(parent_module, current_module_name, original_module)

            del tmp_state_dict

            # Keep the model config in sync with the restored input dimension.
            if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
                attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
                new_value = int(current_param_weight.shape[1])
                old_value = getattr(transformer.config, attribute_name)
                setattr(transformer.config, attribute_name, new_value)
                logger.info(
                    f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
                )
|
1940
|
+
|
1941
|
+
@classmethod
def _maybe_expand_transformer_param_shape_or_error_(
    cls,
    transformer: torch.nn.Module,
    lora_state_dict=None,
    norm_state_dict=None,
    prefix=None,
) -> bool:
    """
    Expand `torch.nn.Linear` layers of `transformer` that are too small for an incoming LoRA checkpoint.

    Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and
    generalizes things a bit so that any linear layer that needs expansion receives appropriate treatment. The layer
    is replaced by a larger `torch.nn.Linear`. The original weight (and bias) occupy the leading rows/columns and the
    newly added entries are zero. The replaced tensors are recorded on `transformer._overwritten_params` for
    `unload_lora_weights()`, keyed by the module path without PEFT's `.base_layer` segment.

    Args:
        transformer (`torch.nn.Module`):
            Model whose linear layers may be replaced in place.
        lora_state_dict (`dict`, *optional*):
            LoRA state dict. For each linear module, `<module>.lora_A.weight.shape[1]` gives the required number of
            input features and `<module>.lora_B.weight.shape[0]` the required number of output features.
        norm_state_dict (`dict`, *optional*):
            Extra state dict merged with `lora_state_dict` before the lookup.
        prefix (`str`, *optional*):
            Leading key component to strip from state dict keys. Defaults to `cls.transformer_name`.

    Returns:
        `bool`: `True` if at least one layer was expanded.

    Note:
        Dimensions are only ever grown, never shrunk. Quantized layers are not yet handled specially (see TODO).
    """
    state_dict = {}
    if lora_state_dict is not None:
        state_dict.update(lora_state_dict)
    if norm_state_dict is not None:
        state_dict.update(norm_state_dict)

    # Remove prefix if present
    prefix = prefix or cls.transformer_name
    for key in list(state_dict.keys()):
        if key.split(".")[0] == prefix:
            state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)

    # Expand transformer parameter shapes if they don't match lora
    has_param_with_shape_update = False
    overwritten_params = {}

    is_peft_loaded = getattr(transformer, "peft_config", None) is not None
    for name, module in transformer.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue

        module_weight = module.weight.data
        module_bias = module.bias.data if module.bias is not None else None
        bias = module_bias is not None

        # Once PEFT has wrapped a layer, the original Linear is reachable as `<module>.base_layer`. LoRA keys
        # (and the un-wrapped model) refer to `<module>` itself.
        lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
        lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
        lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
        if lora_A_weight_name not in state_dict:
            continue

        in_features = state_dict[lora_A_weight_name].shape[1]
        out_features = state_dict[lora_B_weight_name].shape[0]

        # Model maybe loaded with different quantization schemes which may flatten the params.
        # `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models
        # preserve weight shape.
        module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)

        # This means there's no need for an expansion in the params, so we simply skip.
        if tuple(module_weight_shape) == (out_features, in_features):
            continue

        # TODO (sayakpaul): We still need to consider if the module we're expanding is
        # quantized and handle it accordingly if that is the case.
        module_out_features, module_in_features = module_weight.shape
        expand_in = in_features > module_in_features
        expand_out = out_features > module_out_features
        if not (expand_in or expand_out):
            continue

        changes = []
        if expand_in:
            changes.append(f"input features from {module_in_features} to {in_features}")
        if expand_out:
            changes.append(f"output features from {module_out_features} to {out_features}")
        logger.debug(
            f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
            f"checkpoint contains higher number of features than expected. Expanding the number of "
            + " and ".join(changes)
            + "."
        )

        has_param_with_shape_update = True
        parent_module_name, _, current_module_name = name.rpartition(".")
        parent_module = transformer.get_submodule(parent_module_name)

        # Never shrink an existing dimension; only grow the ones the LoRA requires to be larger.
        new_in_features = max(in_features, module_in_features)
        new_out_features = max(out_features, module_out_features)
        with torch.device("meta"):
            expanded_module = torch.nn.Linear(
                new_in_features, new_out_features, bias=bias, dtype=module_weight.dtype
            )

        # Copy the original weight into the leading (out, in) block; the added rows/columns stay zero.
        new_weight = torch.zeros_like(
            expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
        )
        slices = tuple(slice(0, dim) for dim in module_weight.shape)
        new_weight[slices] = module_weight
        tmp_state_dict = {"weight": new_weight}
        if module_bias is not None:
            # The bias has shape (out_features,), so it only needs zero-padding when the output dimension grows.
            # Without padding, the strict load below would fail with a size mismatch.
            if expand_out:
                new_bias = torch.zeros(new_out_features, device=module_bias.device, dtype=module_bias.dtype)
                new_bias[: module_bias.shape[0]] = module_bias
                tmp_state_dict["bias"] = new_bias
            else:
                tmp_state_dict["bias"] = module_bias
        expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)

        setattr(parent_module, current_module_name, expanded_module)

        del tmp_state_dict

        # Config attributes are keyed by the plain module name (e.g. not PEFT's `base_layer`).
        module_leaf_name = lora_base_name.rpartition(".")[2]
        if module_leaf_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
            attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[module_leaf_name]
            new_value = int(expanded_module.weight.data.shape[1])
            old_value = getattr(transformer.config, attribute_name)
            setattr(transformer.config, attribute_name, new_value)
            logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.")

        # For `unload_lora_weights()`. The originals are keyed by the full, PEFT-free module path so they can be
        # matched against the names yielded by `transformer.named_modules()` once LoRA layers are removed.
        # TODO: this could lead to more memory overhead if the number of overwritten params
        # are large. Should be revisited later and tackled through a `discard_original_layers` arg.
        overwritten_params[f"{lora_base_name}.weight"] = module_weight
        if module_bias is not None:
            overwritten_params[f"{lora_base_name}.bias"] = module_bias

    if len(overwritten_params) > 0:
        transformer._overwritten_params = overwritten_params

    return has_param_with_shape_update
|
2061
|
+
|
2062
|
+
@classmethod
def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
    """
    Zero-pad `lora_A` weights whose input dimension is smaller than that of the matching base layer.

    Only keys of the form `f"{cls.transformer_name}.<module>.lora_A.weight"` are considered. Modules that do not
    exist in `transformer` are logged and ignored. For every other module, the base layer's logical weight shape
    (via `_calculate_module_shape`, so packed 4-bit weights are measured correctly) is compared with `lora_A`. A
    narrower `lora_A` gets extra zero columns, so the LoRA contributes nothing along the additional input features.

    Args:
        transformer (`torch.nn.Module`):
            Model the LoRA will be loaded into.
        lora_state_dict (`dict`):
            LoRA state dict. It is updated in place.

    Returns:
        `dict`: The same `lora_state_dict` object, with any padded `lora_A` weights replaced.

    Raises:
        `NotImplementedError`: If a `lora_A` weight has more input features than its base layer.
    """
    expanded_module_names = set()
    transformer_state_dict = transformer.state_dict()
    prefix = f"{cls.transformer_name}."

    lora_module_names = [
        key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
    ]
    lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
    lora_module_names = sorted(set(lora_module_names))
    transformer_module_names = {name for name, _ in transformer.named_modules()}
    unexpected_modules = set(lora_module_names) - transformer_module_names
    if unexpected_modules:
        logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")

    is_peft_loaded = getattr(transformer, "peft_config", None) is not None
    for k in lora_module_names:
        if k in unexpected_modules:
            continue

        # `k` already had `prefix` stripped above, so it is a module path inside `transformer`.
        # Stripping it again could corrupt paths that legitimately contain the prefix text.
        peft_base_param_name = f"{k}.base_layer.weight"
        if is_peft_loaded and peft_base_param_name in transformer_state_dict:
            base_param_name = peft_base_param_name
        else:
            base_param_name = f"{k}.weight"
        base_weight_param = transformer_state_dict[base_param_name]
        lora_A_key = f"{prefix}{k}.lora_A.weight"
        lora_A_param = lora_state_dict[lora_A_key]

        # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
        # Use the logical shape for both the comparison and the padding width. Quantized (e.g. bnb 4-bit)
        # weights may be stored flattened, so `base_weight_param.shape` is not reliable.
        base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
        base_in_features = base_module_shape[1]
        lora_in_features = lora_A_param.shape[1]

        if base_in_features > lora_in_features:
            # Keep the LoRA's own dtype so the padded tensor matches the rest of the checkpoint.
            expanded_state_dict_weight = torch.zeros(
                (lora_A_param.shape[0], base_in_features),
                device=base_weight_param.device,
                dtype=lora_A_param.dtype,
            )
            expanded_state_dict_weight[:, :lora_in_features].copy_(lora_A_param)
            lora_state_dict[lora_A_key] = expanded_state_dict_weight
            expanded_module_names.add(k)
        elif base_in_features < lora_in_features:
            raise NotImplementedError(
                f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
            )

    if expanded_module_names:
        logger.info(
            f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
        )

    return lora_state_dict
|
2111
|
+
|
2112
|
+
@staticmethod
def _calculate_module_shape(
    model: "torch.nn.Module",
    base_module: "torch.nn.Linear" = None,
    base_weight_param_name: str = None,
) -> "torch.Size":
    """
    Return the logical shape of a linear module's weight.

    For bitsandbytes 4-bit parameters (class name `Params4bit`), the shape stored on `weight.quant_state` is
    returned instead of the packed tensor's own `.shape`.

    Args:
        model (`torch.nn.Module`):
            Model used to resolve `base_weight_param_name` into a submodule.
        base_module (`torch.nn.Linear`, *optional*):
            Module whose `.weight` is inspected. Checked first when provided.
        base_weight_param_name (`str`, *optional*):
            Dotted parameter name ending in `.weight`; the owning submodule is looked up in `model`.

    Raises:
        `ValueError`: If `base_weight_param_name` does not end with `.weight`, or if neither `base_module` nor
            `base_weight_param_name` is given.
    """
    if base_module is not None:
        weight = base_module.weight
    elif base_weight_param_name is not None:
        if not base_weight_param_name.endswith(".weight"):
            raise ValueError(
                f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
            )
        # The suffix check above guarantees this slice removes exactly the trailing ".weight".
        owner_path = base_weight_param_name[: -len(".weight")]
        weight = get_submodule_by_name(model, owner_path).weight
    else:
        raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")

    # Matched by class name, presumably so bitsandbytes does not have to be imported here.
    if weight.__class__.__name__ == "Params4bit":
        return weight.quant_state.shape
    return weight.shape
|
2133
|
+
|
2134
|
+
|
2135
|
+
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
    r"""
    Load LoRA layers into the transformer ([`UVit2DModel`]) and the text encoder used by Amused.
    """

    _lora_loadable_modules = ["transformer", "text_encoder"]
    transformer_name = TRANSFORMER_NAME
    text_encoder_name = TEXT_ENCODER_NAME

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
    def load_lora_into_transformer(
        cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `transformer`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
            network_alphas (`Dict[str, float]`):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            transformer (`UVit2DModel`):
                The Transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        keys = list(state_dict.keys())
        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
        if transformer_present:
            logger.info(f"Loading {cls.transformer_name}.")
            transformer.load_lora_adapter(
                state_dict,
                network_alphas=network_alphas,
                adapter_name=adapter_name,
                _pipeline=_pipeline,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
    def load_lora_into_text_encoder(
        cls,
        state_dict,
        network_alphas,
        text_encoder,
        prefix=None,
        lora_scale=1.0,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The key should be prefixed with an
                additional `text_encoder` to distinguish between unet lora layers.
            network_alphas (`Dict[str, float]`):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            text_encoder (`CLIPTextModel`):
                The text encoder model to load the LoRA layers into.
            prefix (`str`):
                Expected prefix of the `text_encoder` in the `state_dict`.
            lora_scale (`float`):
                How much to scale the output of the lora linear layer before it is added with the output of the regular
                lora layer.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        _load_lora_into_text_encoder(
            state_dict=state_dict,
            network_alphas=network_alphas,
            lora_scale=lora_scale,
            text_encoder=text_encoder,
            prefix=prefix,
            text_encoder_name=cls.text_encoder_name,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the transformer and text encoder.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`, *optional*):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`, *optional*):
                State dict of the LoRA layers corresponding to the `transformer`.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            weight_name (`str`, *optional*):
                Name of the weights file; passed through to `write_lora_layers`.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.

        Raises:
            `ValueError`: If neither `transformer_lora_layers` nor `text_encoder_lora_layers` is provided.
        """
        state_dict = {}

        if not (transformer_lora_layers or text_encoder_lora_layers):
            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

        # Each component's weights are packed under its own key prefix so they can be told apart on load.
        if transformer_lora_layers:
            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

        if text_encoder_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )
|
2289
|
+
|
2290
|
+
|
2291
|
+
class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
2292
|
+
r"""
|
2293
|
+
Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`].
|
2294
|
+
"""
|
2295
|
+
|
2296
|
+
_lora_loadable_modules = ["transformer"]
|
2297
|
+
transformer_name = TRANSFORMER_NAME
|
2298
|
+
|
2299
|
+
@classmethod
|
2300
|
+
@validate_hf_hub_args
|
2301
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
2302
|
+
def lora_state_dict(
|
2303
|
+
cls,
|
2304
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
2305
|
+
**kwargs,
|
2306
|
+
):
|
2307
|
+
r"""
|
2308
|
+
Return state dict for lora weights and the network alphas.
|
2309
|
+
|
2310
|
+
<Tip warning={true}>
|
2311
|
+
|
2312
|
+
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
2313
|
+
|
2314
|
+
This function is experimental and might change in the future.
|
2315
|
+
|
2316
|
+
</Tip>
|
2317
|
+
|
2318
|
+
Parameters:
|
2319
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
2320
|
+
Can be either:
|
2321
|
+
|
2322
|
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
2323
|
+
the Hub.
|
2324
|
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
2325
|
+
with [`ModelMixin.save_pretrained`].
|
2326
|
+
- A [torch state
|
2327
|
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
2328
|
+
|
2329
|
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
2330
|
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
2331
|
+
is not used.
|
2332
|
+
force_download (`bool`, *optional*, defaults to `False`):
|
2333
|
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
2334
|
+
cached versions if they exist.
|
2335
|
+
|
2336
|
+
proxies (`Dict[str, str]`, *optional*):
|
2337
|
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
2338
|
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
2339
|
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
2340
|
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
2341
|
+
won't be downloaded from the Hub.
|
2342
|
+
token (`str` or *bool*, *optional*):
|
2343
|
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
2344
|
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
2345
|
+
revision (`str`, *optional*, defaults to `"main"`):
|
2346
|
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
2347
|
+
allowed by Git.
|
2348
|
+
subfolder (`str`, *optional*, defaults to `""`):
|
2349
|
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
2350
|
+
|
2351
|
+
"""
|
2352
|
+
# Load the main state dict first which has the LoRA layers for either of
|
2353
|
+
# transformer and text encoder or both.
|
2354
|
+
cache_dir = kwargs.pop("cache_dir", None)
|
2355
|
+
force_download = kwargs.pop("force_download", False)
|
2356
|
+
proxies = kwargs.pop("proxies", None)
|
2357
|
+
local_files_only = kwargs.pop("local_files_only", None)
|
2358
|
+
token = kwargs.pop("token", None)
|
2359
|
+
revision = kwargs.pop("revision", None)
|
2360
|
+
subfolder = kwargs.pop("subfolder", None)
|
2361
|
+
weight_name = kwargs.pop("weight_name", None)
|
2362
|
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
2363
|
+
|
2364
|
+
allow_pickle = False
|
2365
|
+
if use_safetensors is None:
|
2366
|
+
use_safetensors = True
|
2367
|
+
allow_pickle = True
|
2368
|
+
|
2369
|
+
user_agent = {
|
2370
|
+
"file_type": "attn_procs_weights",
|
2371
|
+
"framework": "pytorch",
|
2372
|
+
}
|
2373
|
+
|
2374
|
+
state_dict = _fetch_state_dict(
|
2375
|
+
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
2376
|
+
weight_name=weight_name,
|
2377
|
+
use_safetensors=use_safetensors,
|
2378
|
+
local_files_only=local_files_only,
|
2379
|
+
cache_dir=cache_dir,
|
2380
|
+
force_download=force_download,
|
2381
|
+
proxies=proxies,
|
2382
|
+
token=token,
|
2383
|
+
revision=revision,
|
2384
|
+
subfolder=subfolder,
|
2385
|
+
user_agent=user_agent,
|
2386
|
+
allow_pickle=allow_pickle,
|
2387
|
+
)
|
2388
|
+
|
2389
|
+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
2390
|
+
if is_dora_scale_present:
|
2391
|
+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
2392
|
+
logger.warning(warn_msg)
|
2393
|
+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
2394
|
+
|
2395
|
+
return state_dict
|
2396
|
+
|
2397
|
+
def load_lora_weights(
|
2398
|
+
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
2399
|
+
):
|
2400
|
+
"""
|
2401
|
+
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
2402
|
+
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
2403
|
+
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
2404
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
2405
|
+
dict is loaded into `self.transformer`.
|
2406
|
+
|
2407
|
+
Parameters:
|
2408
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
2409
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
2410
|
+
adapter_name (`str`, *optional*):
|
2411
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
2412
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
2413
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
2414
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
2415
|
+
weights.
|
2416
|
+
kwargs (`dict`, *optional*):
|
2417
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
2418
|
+
"""
|
2419
|
+
if not USE_PEFT_BACKEND:
|
2420
|
+
raise ValueError("PEFT backend is required for this method.")
|
2421
|
+
|
2422
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
2423
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
2424
|
+
raise ValueError(
|
2425
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
2426
|
+
)
|
2427
|
+
|
2428
|
+
# if a dict is passed, copy it instead of modifying it inplace
|
2429
|
+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
2430
|
+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
2431
|
+
|
2432
|
+
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
2433
|
+
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
2434
|
+
|
2435
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
2436
|
+
if not is_correct_format:
|
2437
|
+
raise ValueError("Invalid LoRA checkpoint.")
|
2438
|
+
|
2439
|
+
self.load_lora_into_transformer(
|
2440
|
+
state_dict,
|
2441
|
+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
2442
|
+
adapter_name=adapter_name,
|
2443
|
+
_pipeline=self,
|
2444
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
2445
|
+
)
|
2446
|
+
|
2447
|
+
@classmethod
|
2448
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
|
2449
|
+
def load_lora_into_transformer(
|
2450
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
2451
|
+
):
|
2452
|
+
"""
|
2453
|
+
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
2454
|
+
|
2455
|
+
Parameters:
|
2456
|
+
state_dict (`dict`):
|
2457
|
+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
2458
|
+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
2459
|
+
encoder lora layers.
|
2460
|
+
transformer (`CogVideoXTransformer3DModel`):
|
2461
|
+
The Transformer model to load the LoRA layers into.
|
2462
|
+
adapter_name (`str`, *optional*):
|
2463
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
2464
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
2465
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
2466
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
2467
|
+
weights.
|
2468
|
+
"""
|
2469
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
2470
|
+
raise ValueError(
|
2471
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
2472
|
+
)
|
2473
|
+
|
2474
|
+
# Load the layers corresponding to transformer.
|
2475
|
+
logger.info(f"Loading {cls.transformer_name}.")
|
2476
|
+
transformer.load_lora_adapter(
|
2477
|
+
state_dict,
|
2478
|
+
network_alphas=None,
|
2479
|
+
adapter_name=adapter_name,
|
2480
|
+
_pipeline=_pipeline,
|
2481
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
2482
|
+
)
|
2483
|
+
|
2484
|
+
@classmethod
|
2485
|
+
# Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
|
2486
|
+
def save_lora_weights(
|
2487
|
+
cls,
|
2488
|
+
save_directory: Union[str, os.PathLike],
|
2489
|
+
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
2490
|
+
is_main_process: bool = True,
|
2491
|
+
weight_name: str = None,
|
2492
|
+
save_function: Callable = None,
|
2493
|
+
safe_serialization: bool = True,
|
2494
|
+
):
|
2495
|
+
r"""
|
2496
|
+
Save the LoRA parameters corresponding to the UNet and text encoder.
|
2497
|
+
|
2498
|
+
Arguments:
|
2499
|
+
save_directory (`str` or `os.PathLike`):
|
2500
|
+
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
2501
|
+
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
2502
|
+
State dict of the LoRA layers corresponding to the `transformer`.
|
2503
|
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
2504
|
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
2505
|
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
2506
|
+
process to avoid race conditions.
|
2507
|
+
save_function (`Callable`):
|
2508
|
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
2509
|
+
replace `torch.save` with another method. Can be configured with the environment variable
|
2510
|
+
`DIFFUSERS_SAVE_MODE`.
|
2511
|
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
2512
|
+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
2513
|
+
"""
|
2514
|
+
state_dict = {}
|
2515
|
+
|
2516
|
+
if not transformer_lora_layers:
|
2517
|
+
raise ValueError("You must pass `transformer_lora_layers`.")
|
2518
|
+
|
2519
|
+
if transformer_lora_layers:
|
2520
|
+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
2521
|
+
|
2522
|
+
# Save the model
|
2523
|
+
cls.write_lora_layers(
|
2524
|
+
state_dict=state_dict,
|
2525
|
+
save_directory=save_directory,
|
2526
|
+
is_main_process=is_main_process,
|
2527
|
+
weight_name=weight_name,
|
2528
|
+
save_function=save_function,
|
2529
|
+
safe_serialization=safe_serialization,
|
2530
|
+
)
|
2531
|
+
|
2532
|
+
def fuse_lora(
    self,
    components: Optional[List[str]] = None,
    lora_scale: float = 1.0,
    safe_fusing: bool = False,
    adapter_names: Optional[List[str]] = None,
    **kwargs,
):
    r"""
    Fuses the LoRA parameters into the original parameters of the corresponding blocks.

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Args:
        components (`List[str]`, *optional*, defaults to `["transformer"]`):
            List of LoRA-injectable components to fuse the LoRAs into. When not passed, the LoRAs are fused into
            the `transformer`.
        lora_scale (`float`, defaults to 1.0):
            Controls how much to influence the outputs with the LoRA parameters.
        safe_fusing (`bool`, defaults to `False`):
            Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
        adapter_names (`List[str]`, *optional*):
            Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
        kwargs (`dict`, *optional*):
            Accepted for signature compatibility only; they are not forwarded and are otherwise ignored.

    Example:

    ```py
    from diffusers import DiffusionPipeline
    import torch

    pipeline = DiffusionPipeline.from_pretrained("<base-model-id-or-path>", torch_dtype=torch.float16).to("cuda")
    pipeline.load_lora_weights("<lora-id-or-path>", adapter_name="my_lora")
    pipeline.fuse_lora(lora_scale=0.7)
    ```
    """
    # Resolve the default here instead of in the signature: a list literal default is created once at
    # definition time and shared by every call, so any in-place mutation downstream would leak across calls.
    if components is None:
        components = ["transformer"]
    super().fuse_lora(
        components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
    )
|
2574
|
+
|
2575
|
+
def unfuse_lora(self, components: Optional[List[str]] = None, **kwargs):
    r"""
    Reverses the effect of
    [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Args:
        components (`List[str]`, *optional*, defaults to `["transformer"]`):
            List of LoRA-injectable components to unfuse LoRA from. When not passed, the `transformer` is unfused.
        kwargs (`dict`, *optional*):
            Accepted for signature compatibility only; they are not forwarded and are otherwise ignored.
    """
    # Resolve the default here instead of in the signature so no list object is shared across calls.
    if components is None:
        components = ["transformer"]
    super().unfuse_lora(components=components)
|
2591
|
+
|
2592
|
+
|
2593
|
+
class Mochi1LoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].

    Only the `transformer` component is LoRA-loadable; `fuse_lora` and `unfuse_lora` default to it.
    """

    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Return state dict for lora weights and the network alphas.

        <Tip warning={true}>

        We support loading A1111 formatted LoRA checkpoints in a limited capacity.

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.

            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.

        """
        # Load the main state dict first which has the LoRA layers for either of
        # transformer and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        return state_dict

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
        dict is loaded into `self.transformer`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            kwargs (`dict`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_transformer(
            state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
    def load_lora_into_transformer(
        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `transformer`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
            transformer (`MochiTransformer3DModel`):
                The Transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and text encoder.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `transformer`.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
        """
        state_dict = {}

        if not transformer_lora_layers:
            raise ValueError("You must pass `transformer_lora_layers`.")

        if transformer_lora_layers:
            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    def fuse_lora(
        self,
        components: Optional[List[str]] = None,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`, *optional*, defaults to `["transformer"]`):
                List of LoRA-injectable components to fuse the LoRAs into. When not passed, the LoRAs are fused into
                the `transformer`.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
            kwargs (`dict`, *optional*):
                Accepted for signature compatibility only; they are not forwarded and are otherwise ignored.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained("<base-model-id-or-path>", torch_dtype=torch.float16).to("cuda")
        pipeline.load_lora_weights("<lora-id-or-path>", adapter_name="my_lora")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        # Resolve the default here instead of in the signature: a list literal default is created once at
        # definition time and shared by every call, so any in-place mutation downstream would leak across calls.
        if components is None:
            components = ["transformer"]
        super().fuse_lora(
            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
        )

    def unfuse_lora(self, components: Optional[List[str]] = None, **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`, *optional*, defaults to `["transformer"]`):
                List of LoRA-injectable components to unfuse LoRA from. When not passed, the `transformer` is unfused.
            kwargs (`dict`, *optional*):
                Accepted for signature compatibility only; they are not forwarded and are otherwise ignored.
        """
        # Resolve the default here instead of in the signature so no list object is shared across calls.
        if components is None:
            components = ["transformer"]
        super().unfuse_lora(components=components)
|
2894
|
+
|
2895
|
+
|
2896
|
+
class LTXVideoLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`].

    Only the `transformer` component is LoRA-loadable; `fuse_lora` and `unfuse_lora` default to it.
    """

    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Return state dict for lora weights and the network alphas.

        <Tip warning={true}>

        We support loading A1111 formatted LoRA checkpoints in a limited capacity.

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.

            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.

        """
        # Load the main state dict first which has the LoRA layers for either of
        # transformer and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        return state_dict

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
        dict is loaded into `self.transformer`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            kwargs (`dict`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_transformer(
            state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
    def load_lora_into_transformer(
        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `transformer`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
            transformer (`LTXVideoTransformer3DModel`):
                The Transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and text encoder.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `transformer`.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
        """
        state_dict = {}

        if not transformer_lora_layers:
            raise ValueError("You must pass `transformer_lora_layers`.")

        if transformer_lora_layers:
            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    def fuse_lora(
        self,
        components: Optional[List[str]] = None,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`, *optional*, defaults to `["transformer"]`):
                List of LoRA-injectable components to fuse the LoRAs into. When not passed, the LoRAs are fused into
                the `transformer`.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
            kwargs (`dict`, *optional*):
                Accepted for signature compatibility only; they are not forwarded and are otherwise ignored.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained("<base-model-id-or-path>", torch_dtype=torch.float16).to("cuda")
        pipeline.load_lora_weights("<lora-id-or-path>", adapter_name="my_lora")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        # Resolve the default here instead of in the signature: a list literal default is created once at
        # definition time and shared by every call, so any in-place mutation downstream would leak across calls.
        if components is None:
            components = ["transformer"]
        super().fuse_lora(
            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
        )

    def unfuse_lora(self, components: Optional[List[str]] = None, **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`, *optional*, defaults to `["transformer"]`):
                List of LoRA-injectable components to unfuse LoRA from. When not passed, the `transformer` is unfused.
            kwargs (`dict`, *optional*):
                Accepted for signature compatibility only; they are not forwarded and are otherwise ignored.
        """
        # Resolve the default here instead of in the signature so no list object is shared across calls.
        if components is None:
            components = ["transformer"]
        super().unfuse_lora(components=components)
|
3197
|
+
|
3198
|
+
|
3199
|
+
class SanaLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
    """

    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Return state dict for lora weights and the network alphas.

        <Tip warning={true}>

        We support loading A1111 formatted LoRA checkpoints in a limited capacity.

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.

            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.

        """
        # Load the main state dict first which has the LoRA layers for either of
        # transformer and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        return state_dict

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
        dict is loaded into `self.transformer`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            kwargs (`dict`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_transformer(
            state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
    def load_lora_into_transformer(
        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `transformer`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
            transformer (`SanaTransformer2DModel`):
                The Transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and text encoder.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `transformer`.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
        """
        state_dict = {}

        if not transformer_lora_layers:
            raise ValueError("You must pass `transformer_lora_layers`.")

        if transformer_lora_layers:
            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    def fuse_lora(
        self,
        components: Optional[List[str]] = None,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`, *optional*, defaults to `["transformer"]`):
                List of LoRA-injectable components to fuse the LoRAs into.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
            kwargs (`dict`, *optional*):
                Forwarded unchanged to [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] instead of being silently
                discarded.

        Example:

        ```py
        from diffusers import SanaPipeline
        import torch

        pipeline = SanaPipeline.from_pretrained("path/to/sana-checkpoint", torch_dtype=torch.float16).to("cuda")
        pipeline.load_lora_weights("path/to/sana-lora", adapter_name="my_lora")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        # A `None` sentinel avoids sharing one mutable list object across every call.
        if components is None:
            components = ["transformer"]
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    def unfuse_lora(self, components: Optional[List[str]] = None, **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`, *optional*, defaults to `["transformer"]`):
                List of LoRA-injectable components to unfuse LoRA from.
            kwargs (`dict`, *optional*):
                Forwarded unchanged to [`~loaders.lora_base.LoraBaseMixin.unfuse_lora`] instead of being silently
                discarded.
        """
        # A `None` sentinel avoids sharing one mutable list object across every call.
        if components is None:
            components = ["transformer"]
        super().unfuse_lora(components=components, **kwargs)
|
3500
|
+
|
3501
|
+
|
3502
|
+
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
    """

    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Return the state dict for the LoRA weights. Network alphas are not returned by this mixin.

        <Tip warning={true}>

        We support loading original format HunyuanVideo LoRA checkpoints. They are detected by the presence of
        `img_attn_qkv` in any key and are converted to the Diffusers layout before being returned.

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            weight_name (`str`, *optional*):
                Name of the weight file to load.
            use_safetensors (`bool`, *optional*):
                If left unset, safetensors is tried first and pickled checkpoints are allowed as a fallback.

        Returns:
            `dict`: The LoRA state dict, with any `dora_scale` entries removed.
        """
        # Pull the Hub/download options out of kwargs; only the transformer LoRA layers are loaded here.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        # When the caller did not choose a format, prefer safetensors but allow a pickle fallback.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        # DoRA magnitude vectors are not supported; drop them with a warning rather than failing.
        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        # Original-format HunyuanVideo checkpoints use fused `img_attn_qkv` keys; convert them to Diffusers naming.
        is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
        if is_original_hunyuan_video:
            state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)

        return state_dict

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
        dict is loaded into `self.transformer`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            kwargs (`dict`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_transformer(
            state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
    def load_lora_into_transformer(
        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `transformer`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
            transformer (`HunyuanVideoTransformer3DModel`):
                The Transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and text encoder.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `transformer`.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
        """
        state_dict = {}

        if not transformer_lora_layers:
            raise ValueError("You must pass `transformer_lora_layers`.")

        if transformer_lora_layers:
            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    def fuse_lora(
        self,
        components: Optional[List[str]] = None,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`, *optional*, defaults to `["transformer"]`):
                List of LoRA-injectable components to fuse the LoRAs into.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
            kwargs (`dict`, *optional*):
                Forwarded unchanged to [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] instead of being silently
                discarded.

        Example:

        ```py
        from diffusers import HunyuanVideoPipeline
        import torch

        pipeline = HunyuanVideoPipeline.from_pretrained(
            "path/to/hunyuan-video-checkpoint", torch_dtype=torch.bfloat16
        ).to("cuda")
        pipeline.load_lora_weights("path/to/hunyuan-video-lora", adapter_name="my_lora")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        # A `None` sentinel avoids sharing one mutable list object across every call.
        if components is None:
            components = ["transformer"]
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    def unfuse_lora(self, components: Optional[List[str]] = None, **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`, *optional*, defaults to `["transformer"]`):
                List of LoRA-injectable components to unfuse LoRA from.
            kwargs (`dict`, *optional*):
                Forwarded unchanged to [`~loaders.lora_base.LoraBaseMixin.unfuse_lora`] instead of being silently
                discarded.
        """
        # A `None` sentinel avoids sharing one mutable list object across every call.
        if components is None:
            components = ["transformer"]
        super().unfuse_lora(components=components, **kwargs)
|
3806
|
+
|
3807
|
+
|
3808
|
+
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
    """Deprecated alias of `StableDiffusionLoraLoaderMixin`, kept only for backward compatibility."""

    def __init__(self, *args, **kwargs):
        # Announce the deprecation first, then defer entirely to the parent initializer.
        deprecate(
            "LoraLoaderMixin",
            "1.0.0",
            "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead.",
        )
        super().__init__(*args, **kwargs)
|