diffusers 0.33.0__py3-none-any.whl → 0.34.0__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- diffusers/__init__.py +48 -1
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/diffusers_cli.py +1 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/hooks/faster_cache.py +2 -2
- diffusers/hooks/group_offloading.py +128 -29
- diffusers/hooks/hooks.py +2 -2
- diffusers/hooks/layerwise_casting.py +3 -3
- diffusers/hooks/pyramid_attention_broadcast.py +1 -1
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +4 -0
- diffusers/loaders/ip_adapter.py +5 -14
- diffusers/loaders/lora_base.py +212 -111
- diffusers/loaders/lora_conversion_utils.py +275 -34
- diffusers/loaders/lora_pipeline.py +1554 -819
- diffusers/loaders/peft.py +52 -109
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +20 -4
- diffusers/loaders/single_file_utils.py +225 -5
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +1 -1
- diffusers/loaders/transformer_sd3.py +2 -2
- diffusers/loaders/unet.py +2 -16
- diffusers/loaders/unet_loader_utils.py +1 -1
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +15 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +4 -4
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +14 -10
- diffusers/models/auto_model.py +47 -10
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +1 -1
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +16 -15
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +1 -1
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +10 -12
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/modeling_utils.py +44 -14
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +742 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +317 -25
- diffusers/models/transformers/transformer_cosmos.py +579 -0
- diffusers/models/transformers/transformer_flux.py +9 -11
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +2 -2
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +24 -8
- diffusers/models/transformers/transformer_wan_vace.py +393 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +2 -2
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/pipelines/__init__.py +37 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +6 -7
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -1
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +10 -17
- diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +3 -4
- diffusers/pipelines/pipeline_loading_utils.py +89 -13
- diffusers/pipelines/pipeline_utils.py +105 -33
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +17 -12
- diffusers/pipelines/wan/pipeline_wan_i2v.py +42 -20
- diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +18 -18
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +179 -1
- diffusers/quantizers/base.py +6 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +16 -13
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +8 -8
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -1
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
- diffusers/schedulers/scheduling_utils.py +1 -1
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +13 -5
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +120 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
- diffusers/utils/dynamic_modules_utils.py +21 -3
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/import_utils.py +81 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +91 -8
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +59 -7
- diffusers/utils/torch_utils.py +25 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/METADATA +3 -3
- diffusers-0.34.0.dist-info/RECORD +639 -0
- diffusers-0.33.0.dist-info/RECORD +0 -608
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/WHEEL +0 -0
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/lora_pipeline.py
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -37,12 +37,16 @@ from .lora_base import ( # noqa
     LoraBaseMixin,
     _fetch_state_dict,
     _load_lora_into_text_encoder,
+    _pack_dict_with_prefix,
 )
 from .lora_conversion_utils import (
     _convert_bfl_flux_control_lora_to_diffusers,
     _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
+    _convert_musubi_wan_lora_to_diffusers,
+    _convert_non_diffusers_hidream_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
+    _convert_non_diffusers_ltxv_lora_to_diffusers,
     _convert_non_diffusers_lumina2_lora_to_diffusers,
     _convert_non_diffusers_wan_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
@@ -78,30 +82,36 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
         from ..quantizers.gguf.utils import dequantize_gguf_tensor

     is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+    is_bnb_8bit_quantized = module.weight.__class__.__name__ == "Int8Params"
     is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"

     if is_bnb_4bit_quantized and not is_bitsandbytes_available():
         raise ValueError(
             "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
         )
+    if is_bnb_8bit_quantized and not is_bitsandbytes_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `bitsandbytes` (8bits). Install `bitsandbytes` to load quantized checkpoints."
+        )
     if is_gguf_quantized and not is_gguf_available():
         raise ValueError(
             "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
         )

     weight_on_cpu = False
-    if
+    if module.weight.device.type == "cpu":
         weight_on_cpu = True

-    if
+    device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
+    if is_bnb_4bit_quantized or is_bnb_8bit_quantized:
         module_weight = dequantize_bnb_weight(
-            module.weight.
-            state=module.weight.quant_state,
+            module.weight.to(device) if weight_on_cpu else module.weight,
+            state=module.weight.quant_state if is_bnb_4bit_quantized else module.state,
             dtype=model.dtype,
         ).data
     elif is_gguf_quantized:
         module_weight = dequantize_gguf_tensor(
-            module.weight.
+            module.weight.to(device) if weight_on_cpu else module.weight,
         )
         module_weight = module_weight.to(model.dtype)
     else:
@@ -126,7 +136,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
     def load_lora_weights(
         self,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
-        adapter_name=None,
+        adapter_name: Optional[str] = None,
         hotswap: bool = False,
         **kwargs,
     ):
@@ -153,7 +163,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
-            hotswap
+            hotswap (`bool`, *optional*):
                 Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
                 in-place. This means that, instead of loading an additional adapter, this will take the existing
                 adapter weights and replace them with the weights of the new adapter. This can be faster and more
@@ -193,7 +203,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-
+        kwargs["return_lora_metadata"] = True
+        state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -204,6 +215,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             network_alphas=network_alphas,
             unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -217,6 +229,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             lora_scale=self.lora_scale,
             adapter_name=adapter_name,
             _pipeline=self,
+            metadata=metadata,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
         )
@@ -273,6 +286,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
@@ -286,18 +301,16 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         weight_name = kwargs.pop("weight_name", None)
         unet_config = kwargs.pop("unet_config", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -334,7 +347,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
             state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

-
+        out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas)
+        return out

     @classmethod
     def load_lora_into_unet(
@@ -346,6 +360,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
+        metadata=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -367,29 +382,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading only loading the pretrained LoRA weights and not initializing the random
                 weights.
-            hotswap
-
-
-
-
-                torch.compile, loading the new adapter does not require recompilation of the model. When using
-                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
-
-                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
-                to call an additional method before loading the adapter:
-
-                ```py
-                pipeline = ... # load diffusers pipeline
-                max_rank = ... # the highest rank among all LoRAs that you want to load
-                # call *before* compiling and loading the LoRA adapter
-                pipeline.enable_lora_hotswap(target_rank=max_rank)
-                pipeline.load_lora_weights(file_name)
-                # optionally compile the model now
-                ```
-
-                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
-                limitations to this technique, which are documented here:
-                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -408,6 +405,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             prefix=cls.unet_name,
             network_alphas=network_alphas,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -425,6 +423,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
+        metadata=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -450,29 +449,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
-            hotswap
-
-
-
-
-                torch.compile, loading the new adapter does not require recompilation of the model. When using
-                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
-
-                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
-                to call an additional method before loading the adapter:
-
-                ```py
-                pipeline = ... # load diffusers pipeline
-                max_rank = ... # the highest rank among all LoRAs that you want to load
-                # call *before* compiling and loading the LoRA adapter
-                pipeline.enable_lora_hotswap(target_rank=max_rank)
-                pipeline.load_lora_weights(file_name)
-                # optionally compile the model now
-                ```
-
-                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
-                limitations to this technique, which are documented here:
-                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         _load_lora_into_text_encoder(
             state_dict=state_dict,
@@ -482,6 +463,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             prefix=prefix,
             text_encoder_name=cls.text_encoder_name,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -497,6 +479,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        unet_lora_adapter_metadata=None,
+        text_encoder_lora_adapter_metadata=None,
     ):
         r"""
         Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -519,8 +503,13 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            unet_lora_adapter_metadata:
+                LoRA adapter metadata associated with the unet to be serialized with the state dict.
+            text_encoder_lora_adapter_metadata:
+                LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not (unet_lora_layers or text_encoder_lora_layers):
             raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
@@ -531,6 +520,14 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         if text_encoder_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

+        if unet_lora_adapter_metadata:
+            lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
+
+        if text_encoder_lora_adapter_metadata:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
+            )
+
         # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
@@ -539,6 +536,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

     def fuse_lora(
@@ -624,6 +622,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         self,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
         adapter_name: Optional[str] = None,
+        hotswap: bool = False,
         **kwargs,
     ):
         """
@@ -650,6 +649,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
             kwargs (`dict`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
@@ -671,7 +672,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-
+        kwargs["return_lora_metadata"] = True
+        state_dict, network_alphas, metadata = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
             **kwargs,
@@ -686,8 +688,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             network_alphas=network_alphas,
             unet=self.unet,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )
         self.load_lora_into_text_encoder(
             state_dict,
@@ -696,8 +700,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             prefix=self.text_encoder_name,
             lora_scale=self.lora_scale,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )
         self.load_lora_into_text_encoder(
             state_dict,
@@ -706,8 +712,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             prefix=f"{self.text_encoder_name}_2",
             lora_scale=self.lora_scale,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )

     @classmethod
@@ -763,6 +771,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
@@ -776,18 +786,16 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         weight_name = kwargs.pop("weight_name", None)
         unet_config = kwargs.pop("unet_config", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -824,7 +832,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
             state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

-
+        out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas)
+        return out

     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
@@ -837,6 +846,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
+        metadata=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -858,29 +868,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading only loading the pretrained LoRA weights and not initializing the random
                 weights.
-            hotswap
-
-
-
-
-                torch.compile, loading the new adapter does not require recompilation of the model. When using
-                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
-
-                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
-                to call an additional method before loading the adapter:
-
-                ```py
-                pipeline = ... # load diffusers pipeline
-                max_rank = ... # the highest rank among all LoRAs that you want to load
-                # call *before* compiling and loading the LoRA adapter
-                pipeline.enable_lora_hotswap(target_rank=max_rank)
-                pipeline.load_lora_weights(file_name)
-                # optionally compile the model now
-                ```
-
-                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
-                limitations to this technique, which are documented here:
-                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -899,6 +891,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             prefix=cls.unet_name,
             network_alphas=network_alphas,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -917,6 +910,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
+        metadata=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -942,29 +936,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
-            hotswap
-
-
-
-
-                torch.compile, loading the new adapter does not require recompilation of the model. When using
-                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
-
-                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
-                to call an additional method before loading the adapter:
-
-                ```py
-                pipeline = ... # load diffusers pipeline
-                max_rank = ... # the highest rank among all LoRAs that you want to load
-                # call *before* compiling and loading the LoRA adapter
-                pipeline.enable_lora_hotswap(target_rank=max_rank)
-                pipeline.load_lora_weights(file_name)
-                # optionally compile the model now
-                ```
-
-                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
-                limitations to this technique, which are documented here:
-                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         _load_lora_into_text_encoder(
             state_dict=state_dict,
@@ -974,6 +950,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             prefix=prefix,
             text_encoder_name=cls.text_encoder_name,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -990,6 +967,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        unet_lora_adapter_metadata=None,
+        text_encoder_lora_adapter_metadata=None,
+        text_encoder_2_lora_adapter_metadata=None,
     ):
         r"""
         Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -1015,8 +995,15 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            unet_lora_adapter_metadata:
+                LoRA adapter metadata associated with the unet to be serialized with the state dict.
+            text_encoder_lora_adapter_metadata:
+                LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
+            text_encoder_2_lora_adapter_metadata:
+                LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
             raise ValueError(
@@ -1032,6 +1019,19 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         if text_encoder_2_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

+        if unet_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
+
+        if text_encoder_lora_adapter_metadata:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
+            )
+
+        if text_encoder_2_lora_adapter_metadata:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
+            )
+
         cls.write_lora_layers(
             state_dict=state_dict,
             save_directory=save_directory,
@@ -1039,6 +1039,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

     def fuse_lora(
@@ -1172,6 +1173,8 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.

         """
         # Load the main state dict first which has the LoRA layers for either of
@@ -1185,18 +1188,16 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -1217,7 +1218,8 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

-
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out

     def load_lora_weights(
         self,
@@ -1247,29 +1249,8 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
-            hotswap
-
-                in-place. This means that, instead of loading an additional adapter, this will take the existing
-                adapter weights and replace them with the weights of the new adapter. This can be faster and more
-                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
-                torch.compile, loading the new adapter does not require recompilation of the model. When using
-                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
-
-                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
-                to call an additional method before loading the adapter:
-
-                ```py
-                pipeline = ... # load diffusers pipeline
-                max_rank = ... # the highest rank among all LoRAs that you want to load
|
-
# call *before* compiling and loading the LoRA adapter
|
1265
|
-
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
1266
|
-
pipeline.load_lora_weights(file_name)
|
1267
|
-
# optionally compile the model now
|
1268
|
-
```
|
1269
|
-
|
1270
|
-
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
1271
|
-
limitations to this technique, which are documented here:
|
1272
|
-
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
1252
|
+
hotswap (`bool`, *optional*):
|
1253
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
1273
1254
|
kwargs (`dict`, *optional*):
|
1274
1255
|
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
1275
1256
|
"""
|
@@ -1287,7 +1268,8 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1287
1268
|
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
1288
1269
|
|
1289
1270
|
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
1290
|
-
|
1271
|
+
kwargs["return_lora_metadata"] = True
|
1272
|
+
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
1291
1273
|
|
1292
1274
|
is_correct_format = all("lora" in key for key in state_dict.keys())
|
1293
1275
|
if not is_correct_format:
|
@@ -1297,6 +1279,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1297
1279
|
state_dict,
|
1298
1280
|
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
1299
1281
|
adapter_name=adapter_name,
|
1282
|
+
metadata=metadata,
|
1300
1283
|
_pipeline=self,
|
1301
1284
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
1302
1285
|
hotswap=hotswap,
|
@@ -1308,6 +1291,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1308
1291
|
prefix=self.text_encoder_name,
|
1309
1292
|
lora_scale=self.lora_scale,
|
1310
1293
|
adapter_name=adapter_name,
|
1294
|
+
metadata=metadata,
|
1311
1295
|
_pipeline=self,
|
1312
1296
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
1313
1297
|
hotswap=hotswap,
|
@@ -1319,6 +1303,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1319
1303
|
prefix=f"{self.text_encoder_name}_2",
|
1320
1304
|
lora_scale=self.lora_scale,
|
1321
1305
|
adapter_name=adapter_name,
|
1306
|
+
metadata=metadata,
|
1322
1307
|
_pipeline=self,
|
1323
1308
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
1324
1309
|
hotswap=hotswap,
|
@@ -1326,7 +1311,14 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1326
1311
|
|
1327
1312
|
@classmethod
|
1328
1313
|
def load_lora_into_transformer(
|
1329
|
-
cls,
|
1314
|
+
cls,
|
1315
|
+
state_dict,
|
1316
|
+
transformer,
|
1317
|
+
adapter_name=None,
|
1318
|
+
_pipeline=None,
|
1319
|
+
low_cpu_mem_usage=False,
|
1320
|
+
hotswap: bool = False,
|
1321
|
+
metadata=None,
|
1330
1322
|
):
|
1331
1323
|
"""
|
1332
1324
|
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
@@ -1344,29 +1336,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1344
1336
|
low_cpu_mem_usage (`bool`, *optional*):
|
1345
1337
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
1346
1338
|
weights.
|
1347
|
-
hotswap
|
1348
|
-
|
1349
|
-
|
1350
|
-
|
1351
|
-
|
1352
|
-
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
1353
|
-
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
1354
|
-
|
1355
|
-
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
1356
|
-
to call an additional method before loading the adapter:
|
1357
|
-
|
1358
|
-
```py
|
1359
|
-
pipeline = ... # load diffusers pipeline
|
1360
|
-
max_rank = ... # the highest rank among all LoRAs that you want to load
|
1361
|
-
# call *before* compiling and loading the LoRA adapter
|
1362
|
-
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
1363
|
-
pipeline.load_lora_weights(file_name)
|
1364
|
-
# optionally compile the model now
|
1365
|
-
```
|
1366
|
-
|
1367
|
-
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
1368
|
-
limitations to this technique, which are documented here:
|
1369
|
-
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
1339
|
+
hotswap (`bool`, *optional*):
|
1340
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
1341
|
+
metadata (`dict`):
|
1342
|
+
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
|
1343
|
+
from the state dict.
|
1370
1344
|
"""
|
1371
1345
|
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
1372
1346
|
raise ValueError(
|
@@ -1379,6 +1353,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1379
1353
|
state_dict,
|
1380
1354
|
network_alphas=None,
|
1381
1355
|
adapter_name=adapter_name,
|
1356
|
+
metadata=metadata,
|
1382
1357
|
_pipeline=_pipeline,
|
1383
1358
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
1384
1359
|
hotswap=hotswap,
|
@@ -1397,6 +1372,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1397
1372
|
_pipeline=None,
|
1398
1373
|
low_cpu_mem_usage=False,
|
1399
1374
|
hotswap: bool = False,
|
1375
|
+
metadata=None,
|
1400
1376
|
):
|
1401
1377
|
"""
|
1402
1378
|
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
@@ -1422,29 +1398,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1422
1398
|
low_cpu_mem_usage (`bool`, *optional*):
|
1423
1399
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
1424
1400
|
weights.
|
1425
|
-
hotswap
|
1426
|
-
|
1427
|
-
|
1428
|
-
|
1429
|
-
|
1430
|
-
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
1431
|
-
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
1432
|
-
|
1433
|
-
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
1434
|
-
to call an additional method before loading the adapter:
|
1435
|
-
|
1436
|
-
```py
|
1437
|
-
pipeline = ... # load diffusers pipeline
|
1438
|
-
max_rank = ... # the highest rank among all LoRAs that you want to load
|
1439
|
-
# call *before* compiling and loading the LoRA adapter
|
1440
|
-
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
1441
|
-
pipeline.load_lora_weights(file_name)
|
1442
|
-
# optionally compile the model now
|
1443
|
-
```
|
1444
|
-
|
1445
|
-
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
1446
|
-
limitations to this technique, which are documented here:
|
1447
|
-
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
1401
|
+
hotswap (`bool`, *optional*):
|
1402
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
1403
|
+
metadata (`dict`):
|
1404
|
+
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
|
1405
|
+
from the state dict.
|
1448
1406
|
"""
|
1449
1407
|
_load_lora_into_text_encoder(
|
1450
1408
|
state_dict=state_dict,
|
@@ -1454,6 +1412,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1454
1412
|
prefix=prefix,
|
1455
1413
|
text_encoder_name=cls.text_encoder_name,
|
1456
1414
|
adapter_name=adapter_name,
|
1415
|
+
metadata=metadata,
|
1457
1416
|
_pipeline=_pipeline,
|
1458
1417
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
1459
1418
|
hotswap=hotswap,
|
@@ -1471,6 +1430,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1471
1430
|
weight_name: str = None,
|
1472
1431
|
save_function: Callable = None,
|
1473
1432
|
safe_serialization: bool = True,
|
1433
|
+
transformer_lora_adapter_metadata=None,
|
1434
|
+
text_encoder_lora_adapter_metadata=None,
|
1435
|
+
text_encoder_2_lora_adapter_metadata=None,
|
1474
1436
|
):
|
1475
1437
|
r"""
|
1476
1438
|
Save the LoRA parameters corresponding to the UNet and text encoder.
|
@@ -1496,8 +1458,15 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1496
1458
|
`DIFFUSERS_SAVE_MODE`.
|
1497
1459
|
safe_serialization (`bool`, *optional*, defaults to `True`):
|
1498
1460
|
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
1461
|
+
transformer_lora_adapter_metadata:
|
1462
|
+
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
1463
|
+
text_encoder_lora_adapter_metadata:
|
1464
|
+
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
|
1465
|
+
text_encoder_2_lora_adapter_metadata:
|
1466
|
+
LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
|
1499
1467
|
"""
|
1500
1468
|
state_dict = {}
|
1469
|
+
lora_adapter_metadata = {}
|
1501
1470
|
|
1502
1471
|
if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
|
1503
1472
|
raise ValueError(
|
@@ -1513,6 +1482,21 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1513
1482
|
if text_encoder_2_lora_layers:
|
1514
1483
|
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
1515
1484
|
|
1485
|
+
if transformer_lora_adapter_metadata is not None:
|
1486
|
+
lora_adapter_metadata.update(
|
1487
|
+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
1488
|
+
)
|
1489
|
+
|
1490
|
+
if text_encoder_lora_adapter_metadata:
|
1491
|
+
lora_adapter_metadata.update(
|
1492
|
+
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
|
1493
|
+
)
|
1494
|
+
|
1495
|
+
if text_encoder_2_lora_adapter_metadata:
|
1496
|
+
lora_adapter_metadata.update(
|
1497
|
+
_pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
|
1498
|
+
)
|
1499
|
+
|
1516
1500
|
cls.write_lora_layers(
|
1517
1501
|
state_dict=state_dict,
|
1518
1502
|
save_directory=save_directory,
|
@@ -1520,6 +1504,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1520
1504
|
weight_name=weight_name,
|
1521
1505
|
save_function=save_function,
|
1522
1506
|
safe_serialization=safe_serialization,
|
1507
|
+
lora_adapter_metadata=lora_adapter_metadata,
|
1523
1508
|
)
|
1524
1509
|
|
1525
1510
|
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
|
@@ -1592,25 +1577,20 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1592
1577
|
super().unfuse_lora(components=components, **kwargs)
|
1593
1578
|
|
1594
1579
|
|
1595
|
-
class
|
1580
|
+
class AuraFlowLoraLoaderMixin(LoraBaseMixin):
|
1596
1581
|
r"""
|
1597
|
-
Load LoRA layers into [`
|
1598
|
-
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
|
1599
|
-
|
1600
|
-
Specific to [`StableDiffusion3Pipeline`].
|
1582
|
+
Load LoRA layers into [`AuraFlowTransformer2DModel`] Specific to [`AuraFlowPipeline`].
|
1601
1583
|
"""
|
1602
1584
|
|
1603
|
-
_lora_loadable_modules = ["transformer"
|
1585
|
+
_lora_loadable_modules = ["transformer"]
|
1604
1586
|
transformer_name = TRANSFORMER_NAME
|
1605
|
-
text_encoder_name = TEXT_ENCODER_NAME
|
1606
|
-
_control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
|
1607
1587
|
|
1608
1588
|
@classmethod
|
1609
1589
|
@validate_hf_hub_args
|
1590
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
|
1610
1591
|
def lora_state_dict(
|
1611
1592
|
cls,
|
1612
1593
|
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
1613
|
-
return_alphas: bool = False,
|
1614
1594
|
**kwargs,
|
1615
1595
|
):
|
1616
1596
|
r"""
|
@@ -1656,6 +1636,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1656
1636
|
allowed by Git.
|
1657
1637
|
subfolder (`str`, *optional*, defaults to `""`):
|
1658
1638
|
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
1639
|
+
return_lora_metadata (`bool`, *optional*, defaults to False):
|
1640
|
+
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
|
1659
1641
|
|
1660
1642
|
"""
|
1661
1643
|
# Load the main state dict first which has the LoRA layers for either of
|
@@ -1669,18 +1651,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1669
1651
|
subfolder = kwargs.pop("subfolder", None)
|
1670
1652
|
weight_name = kwargs.pop("weight_name", None)
|
1671
1653
|
use_safetensors = kwargs.pop("use_safetensors", None)
|
1654
|
+
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
|
1672
1655
|
|
1673
1656
|
allow_pickle = False
|
1674
1657
|
if use_safetensors is None:
|
1675
1658
|
use_safetensors = True
|
1676
1659
|
allow_pickle = True
|
1677
1660
|
|
1678
|
-
user_agent = {
|
1679
|
-
"file_type": "attn_procs_weights",
|
1680
|
-
"framework": "pytorch",
|
1681
|
-
}
|
1661
|
+
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
1682
1662
|
|
1683
|
-
state_dict = _fetch_state_dict(
|
1663
|
+
state_dict, metadata = _fetch_state_dict(
|
1684
1664
|
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
1685
1665
|
weight_name=weight_name,
|
1686
1666
|
use_safetensors=use_safetensors,
|
@@ -1694,101 +1674,453 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1694
1674
|
user_agent=user_agent,
|
1695
1675
|
allow_pickle=allow_pickle,
|
1696
1676
|
)
|
1677
|
+
|
1697
1678
|
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
1698
1679
|
if is_dora_scale_present:
|
1699
1680
|
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
1700
1681
|
logger.warning(warn_msg)
|
1701
1682
|
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
1702
1683
|
|
1703
|
-
|
1704
|
-
|
1705
|
-
if is_kohya:
|
1706
|
-
state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
|
1707
|
-
# Kohya already takes care of scaling the LoRA parameters with alpha.
|
1708
|
-
return (state_dict, None) if return_alphas else state_dict
|
1709
|
-
|
1710
|
-
is_xlabs = any("processor" in k for k in state_dict)
|
1711
|
-
if is_xlabs:
|
1712
|
-
state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
|
1713
|
-
# xlabs doesn't use `alpha`.
|
1714
|
-
return (state_dict, None) if return_alphas else state_dict
|
1715
|
-
|
1716
|
-
is_bfl_control = any("query_norm.scale" in k for k in state_dict)
|
1717
|
-
if is_bfl_control:
|
1718
|
-
state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
|
1719
|
-
return (state_dict, None) if return_alphas else state_dict
|
1720
|
-
|
1721
|
-
# For state dicts like
|
1722
|
-
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
|
1723
|
-
keys = list(state_dict.keys())
|
1724
|
-
network_alphas = {}
|
1725
|
-
for k in keys:
|
1726
|
-
if "alpha" in k:
|
1727
|
-
alpha_value = state_dict.get(k)
|
1728
|
-
if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
|
1729
|
-
alpha_value, float
|
1730
|
-
):
|
1731
|
-
network_alphas[k] = state_dict.pop(k)
|
1732
|
-
else:
|
1733
|
-
raise ValueError(
|
1734
|
-
f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
|
1735
|
-
)
|
1736
|
-
|
1737
|
-
if return_alphas:
|
1738
|
-
return state_dict, network_alphas
|
1739
|
-
else:
|
1740
|
-
return state_dict
|
1684
|
+
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
1685
|
+
return out
|
1741
1686
|
|
1687
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
1742
1688
|
def load_lora_weights(
|
1743
1689
|
self,
|
1744
1690
|
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
1745
|
-
adapter_name=None,
|
1691
|
+
adapter_name: Optional[str] = None,
|
1746
1692
|
hotswap: bool = False,
|
1747
1693
|
**kwargs,
|
1748
1694
|
):
|
1749
1695
|
"""
|
1750
1696
|
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
1751
|
-
`self.text_encoder`.
|
1752
|
-
|
1753
|
-
All kwargs are forwarded to `self.lora_state_dict`.
|
1754
|
-
|
1755
|
-
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
|
1756
|
-
loaded.
|
1757
|
-
|
1697
|
+
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
1698
|
+
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
1758
1699
|
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
1759
1700
|
dict is loaded into `self.transformer`.
|
1760
1701
|
|
1761
1702
|
Parameters:
|
1762
1703
|
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
1763
1704
|
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
1764
|
-
kwargs (`dict`, *optional*):
|
1765
|
-
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
1766
1705
|
adapter_name (`str`, *optional*):
|
1767
1706
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
1768
1707
|
`default_{i}` where i is the total number of adapters being loaded.
|
1769
1708
|
low_cpu_mem_usage (`bool`, *optional*):
|
1770
|
-
|
1709
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
1771
1710
|
weights.
|
1772
|
-
hotswap
|
1773
|
-
|
1774
|
-
|
1775
|
-
|
1776
|
-
|
1777
|
-
|
1778
|
-
|
1779
|
-
|
1780
|
-
|
1781
|
-
|
1782
|
-
|
1783
|
-
|
1784
|
-
|
1785
|
-
|
1786
|
-
|
1787
|
-
|
1788
|
-
|
1789
|
-
|
1790
|
-
|
1791
|
-
|
1711
|
+
hotswap (`bool`, *optional*):
|
1712
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
1713
|
+
kwargs (`dict`, *optional*):
|
1714
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
1715
|
+
"""
|
1716
|
+
if not USE_PEFT_BACKEND:
|
1717
|
+
raise ValueError("PEFT backend is required for this method.")
|
1718
|
+
|
1719
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
1720
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
1721
|
+
raise ValueError(
|
1722
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
1723
|
+
)
|
1724
|
+
|
1725
|
+
# if a dict is passed, copy it instead of modifying it inplace
|
1726
|
+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
1727
|
+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
1728
|
+
|
1729
|
+
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
1730
|
+
kwargs["return_lora_metadata"] = True
|
1731
|
+
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
1732
|
+
|
1733
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
1734
|
+
if not is_correct_format:
|
1735
|
+
raise ValueError("Invalid LoRA checkpoint.")
|
1736
|
+
|
1737
|
+
self.load_lora_into_transformer(
|
1738
|
+
state_dict,
|
1739
|
+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
1740
|
+
adapter_name=adapter_name,
|
1741
|
+
metadata=metadata,
|
1742
|
+
_pipeline=self,
|
1743
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
1744
|
+
hotswap=hotswap,
|
1745
|
+
)
|
1746
|
+
|
1747
|
+
@classmethod
|
1748
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel
|
1749
|
+
def load_lora_into_transformer(
|
1750
|
+
cls,
|
1751
|
+
state_dict,
|
1752
|
+
transformer,
|
1753
|
+
adapter_name=None,
|
1754
|
+
_pipeline=None,
|
1755
|
+
low_cpu_mem_usage=False,
|
1756
|
+
hotswap: bool = False,
|
1757
|
+
metadata=None,
|
1758
|
+
):
|
1759
|
+
"""
|
1760
|
+
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
1761
|
+
|
1762
|
+
Parameters:
|
1763
|
+
state_dict (`dict`):
|
1764
|
+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
1765
|
+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
1766
|
+
encoder lora layers.
|
1767
|
+
transformer (`AuraFlowTransformer2DModel`):
|
1768
|
+
The Transformer model to load the LoRA layers into.
|
1769
|
+
adapter_name (`str`, *optional*):
|
1770
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
1771
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
1772
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
1773
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
1774
|
+
weights.
|
1775
|
+
hotswap (`bool`, *optional*):
|
1776
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
1777
|
+
metadata (`dict`):
|
1778
|
+
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
|
1779
|
+
from the state dict.
|
1780
|
+
"""
|
1781
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
1782
|
+
raise ValueError(
|
1783
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
1784
|
+
)
|
1785
|
+
|
1786
|
+
# Load the layers corresponding to transformer.
|
1787
|
+
logger.info(f"Loading {cls.transformer_name}.")
|
1788
|
+
transformer.load_lora_adapter(
|
1789
|
+
state_dict,
|
1790
|
+
network_alphas=None,
|
1791
|
+
adapter_name=adapter_name,
|
1792
|
+
metadata=metadata,
|
1793
|
+
_pipeline=_pipeline,
|
1794
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
1795
|
+
hotswap=hotswap,
|
1796
|
+
)
|
1797
|
+
|
1798
|
+
@classmethod
|
1799
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
1800
|
+
def save_lora_weights(
|
1801
|
+
cls,
|
1802
|
+
save_directory: Union[str, os.PathLike],
|
1803
|
+
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
1804
|
+
is_main_process: bool = True,
|
1805
|
+
weight_name: str = None,
|
1806
|
+
save_function: Callable = None,
|
1807
|
+
safe_serialization: bool = True,
|
1808
|
+
transformer_lora_adapter_metadata: Optional[dict] = None,
|
1809
|
+
):
|
1810
|
+
r"""
|
1811
|
+
Save the LoRA parameters corresponding to the transformer.
|
1812
|
+
|
1813
|
+
Arguments:
|
1814
|
+
save_directory (`str` or `os.PathLike`):
|
1815
|
+
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
1816
|
+
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
1817
|
+
State dict of the LoRA layers corresponding to the `transformer`.
|
1818
|
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
1819
|
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
1820
|
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
1821
|
+
process to avoid race conditions.
|
1822
|
+
save_function (`Callable`):
|
1823
|
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
1824
|
+
replace `torch.save` with another method. Can be configured with the environment variable
|
1825
|
+
`DIFFUSERS_SAVE_MODE`.
|
1826
|
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
1827
|
+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
1828
|
+
transformer_lora_adapter_metadata:
|
1829
|
+
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
1830
|
+
"""
|
1831
|
+
state_dict = {}
|
1832
|
+
lora_adapter_metadata = {}
|
1833
|
+
|
1834
|
+
if not transformer_lora_layers:
|
1835
|
+
raise ValueError("You must pass `transformer_lora_layers`.")
|
1836
|
+
|
1837
|
+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
1838
|
+
|
1839
|
+
if transformer_lora_adapter_metadata is not None:
|
1840
|
+
lora_adapter_metadata.update(
|
1841
|
+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
1842
|
+
)
|
1843
|
+
|
1844
|
+
# Save the model
|
1845
|
+
cls.write_lora_layers(
|
1846
|
+
state_dict=state_dict,
|
1847
|
+
save_directory=save_directory,
|
1848
|
+
is_main_process=is_main_process,
|
1849
|
+
weight_name=weight_name,
|
1850
|
+
save_function=save_function,
|
1851
|
+
safe_serialization=safe_serialization,
|
1852
|
+
lora_adapter_metadata=lora_adapter_metadata,
|
1853
|
+
)
|
1854
|
+
|
1855
|
+
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
|
1856
|
+
def fuse_lora(
|
1857
|
+
self,
|
1858
|
+
components: List[str] = ["transformer"],
|
1859
|
+
lora_scale: float = 1.0,
|
1860
|
+
safe_fusing: bool = False,
|
1861
|
+
adapter_names: Optional[List[str]] = None,
|
1862
|
+
**kwargs,
|
1863
|
+
):
|
1864
|
+
r"""
|
1865
|
+
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
1866
|
+
|
1867
|
+
<Tip warning={true}>
|
1868
|
+
|
1869
|
+
This is an experimental API.
|
1870
|
+
|
1871
|
+
</Tip>
|
1872
|
+
|
1873
|
+
Args:
|
1874
|
+
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
1875
|
+
lora_scale (`float`, defaults to 1.0):
|
1876
|
+
Controls how much to influence the outputs with the LoRA parameters.
|
1877
|
+
safe_fusing (`bool`, defaults to `False`):
|
1878
|
+
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
1879
|
+
adapter_names (`List[str]`, *optional*):
|
1880
|
+
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
1881
|
+
|
1882
|
+
Example:
|
1883
|
+
|
1884
|
+
```py
|
1885
|
+
from diffusers import DiffusionPipeline
|
1886
|
+
import torch
|
1887
|
+
|
1888
|
+
pipeline = DiffusionPipeline.from_pretrained(
|
1889
|
+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
1890
|
+
).to("cuda")
|
1891
|
+
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
1892
|
+
pipeline.fuse_lora(lora_scale=0.7)
|
1893
|
+
```
|
1894
|
+
"""
|
1895
|
+
super().fuse_lora(
|
1896
|
+
components=components,
|
1897
|
+
lora_scale=lora_scale,
|
1898
|
+
safe_fusing=safe_fusing,
|
1899
|
+
adapter_names=adapter_names,
|
1900
|
+
**kwargs,
|
1901
|
+
)
|
1902
|
+
|
1903
|
+
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
|
1904
|
+
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
1905
|
+
r"""
|
1906
|
+
Reverses the effect of
|
1907
|
+
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
1908
|
+
|
1909
|
+
<Tip warning={true}>
|
1910
|
+
|
1911
|
+
This is an experimental API.
|
1912
|
+
|
1913
|
+
</Tip>
|
1914
|
+
|
1915
|
+
Args:
|
1916
|
+
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
1917
|
+
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
1918
|
+
"""
|
1919
|
+
super().unfuse_lora(components=components, **kwargs)
|
1920
|
+
|
1921
|
+
|
1922
|
+
class FluxLoraLoaderMixin(LoraBaseMixin):
|
1923
|
+
r"""
|
1924
|
+
Load LoRA layers into [`FluxTransformer2DModel`],
|
1925
|
+
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
|
1926
|
+
|
1927
|
+
Specific to [`StableDiffusion3Pipeline`].
|
1928
|
+
"""
|
1929
|
+
|
1930
|
+
_lora_loadable_modules = ["transformer", "text_encoder"]
|
1931
|
+
transformer_name = TRANSFORMER_NAME
|
1932
|
+
text_encoder_name = TEXT_ENCODER_NAME
|
1933
|
+
_control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
|
1934
|
+
|
1935
|
+
@classmethod
|
1936
|
+
@validate_hf_hub_args
|
1937
|
+
def lora_state_dict(
|
1938
|
+
cls,
|
1939
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
1940
|
+
return_alphas: bool = False,
|
1941
|
+
**kwargs,
|
1942
|
+
):
|
1943
|
+
r"""
|
1944
|
+
Return state dict for lora weights and the network alphas.
|
1945
|
+
|
1946
|
+
<Tip warning={true}>
|
1947
|
+
|
1948
|
+
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
1949
|
+
|
1950
|
+
This function is experimental and might change in the future.
|
1951
|
+
|
1952
|
+
</Tip>
|
1953
|
+
|
1954
|
+
Parameters:
|
1955
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
1956
|
+
Can be either:
|
1957
|
+
|
1958
|
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
1959
|
+
the Hub.
|
1960
|
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
1961
|
+
with [`ModelMixin.save_pretrained`].
|
1962
|
+
- A [torch state
|
1963
|
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
1964
|
+
|
1965
|
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
1966
|
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
1967
|
+
is not used.
|
1968
|
+
force_download (`bool`, *optional*, defaults to `False`):
|
1969
|
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
1970
|
+
cached versions if they exist.
|
1971
|
+
|
1972
|
+
proxies (`Dict[str, str]`, *optional*):
|
1973
|
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
1974
|
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
1975
|
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
1976
|
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
1977
|
+
won't be downloaded from the Hub.
|
1978
|
+
token (`str` or *bool*, *optional*):
|
1979
|
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
1980
|
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
1981
|
+
revision (`str`, *optional*, defaults to `"main"`):
|
1982
|
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
1983
|
+
allowed by Git.
|
1984
|
+
subfolder (`str`, *optional*, defaults to `""`):
|
1985
|
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
1986
|
+
return_lora_metadata (`bool`, *optional*, defaults to False):
|
1987
|
+
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
|
1988
|
+
"""
|
1989
|
+
# Load the main state dict first which has the LoRA layers for either of
|
1990
|
+
# transformer and text encoder or both.
|
1991
|
+
cache_dir = kwargs.pop("cache_dir", None)
|
1992
|
+
force_download = kwargs.pop("force_download", False)
|
1993
|
+
proxies = kwargs.pop("proxies", None)
|
1994
|
+
local_files_only = kwargs.pop("local_files_only", None)
|
1995
|
+
token = kwargs.pop("token", None)
|
1996
|
+
revision = kwargs.pop("revision", None)
|
1997
|
+
subfolder = kwargs.pop("subfolder", None)
|
1998
|
+
weight_name = kwargs.pop("weight_name", None)
|
1999
|
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
2000
|
+
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
|
2001
|
+
|
2002
|
+
allow_pickle = False
|
2003
|
+
if use_safetensors is None:
|
2004
|
+
use_safetensors = True
|
2005
|
+
allow_pickle = True
|
2006
|
+
|
2007
|
+
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
2008
|
+
|
2009
|
+
state_dict, metadata = _fetch_state_dict(
|
2010
|
+
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
2011
|
+
weight_name=weight_name,
|
2012
|
+
use_safetensors=use_safetensors,
|
2013
|
+
local_files_only=local_files_only,
|
2014
|
+
cache_dir=cache_dir,
|
2015
|
+
force_download=force_download,
|
2016
|
+
proxies=proxies,
|
2017
|
+
token=token,
|
2018
|
+
revision=revision,
|
2019
|
+
subfolder=subfolder,
|
2020
|
+
user_agent=user_agent,
|
2021
|
+
allow_pickle=allow_pickle,
|
2022
|
+
)
|
2023
|
+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
2024
|
+
if is_dora_scale_present:
|
2025
|
+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
2026
|
+
logger.warning(warn_msg)
|
2027
|
+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
2028
|
+
|
2029
|
+
# TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
|
2030
|
+
is_kohya = any(".lora_down.weight" in k for k in state_dict)
|
2031
|
+
if is_kohya:
|
2032
|
+
state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
|
2033
|
+
# Kohya already takes care of scaling the LoRA parameters with alpha.
|
2034
|
+
return cls._prepare_outputs(
|
2035
|
+
state_dict,
|
2036
|
+
metadata=metadata,
|
2037
|
+
alphas=None,
|
2038
|
+
return_alphas=return_alphas,
|
2039
|
+
return_metadata=return_lora_metadata,
|
2040
|
+
)
|
2041
|
+
|
2042
|
+
is_xlabs = any("processor" in k for k in state_dict)
|
2043
|
+
if is_xlabs:
|
2044
|
+
state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
|
2045
|
+
# xlabs doesn't use `alpha`.
|
2046
|
+
return cls._prepare_outputs(
|
2047
|
+
state_dict,
|
2048
|
+
metadata=metadata,
|
2049
|
+
alphas=None,
|
2050
|
+
return_alphas=return_alphas,
|
2051
|
+
return_metadata=return_lora_metadata,
|
2052
|
+
)
|
2053
|
+
|
2054
|
+
is_bfl_control = any("query_norm.scale" in k for k in state_dict)
|
2055
|
+
if is_bfl_control:
|
2056
|
+
state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
|
2057
|
+
return cls._prepare_outputs(
|
2058
|
+
state_dict,
|
2059
|
+
metadata=metadata,
|
2060
|
+
alphas=None,
|
2061
|
+
return_alphas=return_alphas,
|
2062
|
+
return_metadata=return_lora_metadata,
|
2063
|
+
)
|
2064
|
+
|
2065
|
+
# For state dicts like
|
2066
|
+
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
|
2067
|
+
keys = list(state_dict.keys())
|
2068
|
+
network_alphas = {}
|
2069
|
+
for k in keys:
|
2070
|
+
if "alpha" in k:
|
2071
|
+
alpha_value = state_dict.get(k)
|
2072
|
+
if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
|
2073
|
+
alpha_value, float
|
2074
|
+
):
|
2075
|
+
network_alphas[k] = state_dict.pop(k)
|
2076
|
+
else:
|
2077
|
+
raise ValueError(
|
2078
|
+
f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
|
2079
|
+
)
|
2080
|
+
|
2081
|
+
if return_alphas or return_lora_metadata:
|
2082
|
+
return cls._prepare_outputs(
|
2083
|
+
state_dict,
|
2084
|
+
metadata=metadata,
|
2085
|
+
alphas=network_alphas,
|
2086
|
+
return_alphas=return_alphas,
|
2087
|
+
return_metadata=return_lora_metadata,
|
2088
|
+
)
|
2089
|
+
else:
|
2090
|
+
return state_dict
|
2091
|
+
|
2092
|
+
def load_lora_weights(
|
2093
|
+
self,
|
2094
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
2095
|
+
adapter_name: Optional[str] = None,
|
2096
|
+
hotswap: bool = False,
|
2097
|
+
**kwargs,
|
2098
|
+
):
|
2099
|
+
"""
|
2100
|
+
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
2101
|
+
`self.text_encoder`.
|
2102
|
+
|
2103
|
+
All kwargs are forwarded to `self.lora_state_dict`.
|
2104
|
+
|
2105
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
|
2106
|
+
loaded.
|
2107
|
+
|
2108
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
2109
|
+
dict is loaded into `self.transformer`.
|
2110
|
+
|
2111
|
+
Parameters:
|
2112
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
2113
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
2114
|
+
adapter_name (`str`, *optional*):
|
2115
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
2116
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
2117
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
2118
|
+
`Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
2119
|
+
weights.
|
2120
|
+
hotswap (`bool`, *optional*):
|
2121
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
2122
|
+
kwargs (`dict`, *optional*):
|
2123
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
1792
2124
|
"""
|
1793
2125
|
if not USE_PEFT_BACKEND:
|
1794
2126
|
raise ValueError("PEFT backend is required for this method.")
|
@@ -1804,7 +2136,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1804
2136
|
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
1805
2137
|
|
1806
2138
|
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
1807
|
-
|
2139
|
+
kwargs["return_lora_metadata"] = True
|
2140
|
+
state_dict, network_alphas, metadata = self.lora_state_dict(
|
1808
2141
|
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
|
1809
2142
|
)
|
1810
2143
|
|
@@ -1855,6 +2188,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1855
2188
|
network_alphas=network_alphas,
|
1856
2189
|
transformer=transformer,
|
1857
2190
|
adapter_name=adapter_name,
|
2191
|
+
metadata=metadata,
|
1858
2192
|
_pipeline=self,
|
1859
2193
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
1860
2194
|
hotswap=hotswap,
|
@@ -1874,6 +2208,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1874
2208
|
prefix=self.text_encoder_name,
|
1875
2209
|
lora_scale=self.lora_scale,
|
1876
2210
|
adapter_name=adapter_name,
|
2211
|
+
metadata=metadata,
|
1877
2212
|
_pipeline=self,
|
1878
2213
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
1879
2214
|
hotswap=hotswap,
|
@@ -1886,6 +2221,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1886
2221
|
network_alphas,
|
1887
2222
|
transformer,
|
1888
2223
|
adapter_name=None,
|
2224
|
+
metadata=None,
|
1889
2225
|
_pipeline=None,
|
1890
2226
|
low_cpu_mem_usage=False,
|
1891
2227
|
hotswap: bool = False,
|
@@ -1910,29 +2246,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1910
2246
|
low_cpu_mem_usage (`bool`, *optional*):
|
1911
2247
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
1912
2248
|
weights.
|
1913
|
-
hotswap
|
1914
|
-
|
1915
|
-
|
1916
|
-
|
1917
|
-
|
1918
|
-
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
1919
|
-
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
1920
|
-
|
1921
|
-
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
1922
|
-
to call an additional method before loading the adapter:
|
1923
|
-
|
1924
|
-
```py
|
1925
|
-
pipeline = ... # load diffusers pipeline
|
1926
|
-
max_rank = ... # the highest rank among all LoRAs that you want to load
|
1927
|
-
# call *before* compiling and loading the LoRA adapter
|
1928
|
-
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
1929
|
-
pipeline.load_lora_weights(file_name)
|
1930
|
-
# optionally compile the model now
|
1931
|
-
```
|
1932
|
-
|
1933
|
-
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
1934
|
-
limitations to this technique, which are documented here:
|
1935
|
-
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
2249
|
+
hotswap (`bool`, *optional*):
|
2250
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
2251
|
+
metadata (`dict`):
|
2252
|
+
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
|
2253
|
+
from the state dict.
|
1936
2254
|
"""
|
1937
2255
|
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
1938
2256
|
raise ValueError(
|
@@ -1945,6 +2263,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1945
2263
|
state_dict,
|
1946
2264
|
network_alphas=network_alphas,
|
1947
2265
|
adapter_name=adapter_name,
|
2266
|
+
metadata=metadata,
|
1948
2267
|
_pipeline=_pipeline,
|
1949
2268
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
1950
2269
|
hotswap=hotswap,
|
@@ -1962,7 +2281,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1962
2281
|
prefix = prefix or cls.transformer_name
|
1963
2282
|
for key in list(state_dict.keys()):
|
1964
2283
|
if key.split(".")[0] == prefix:
|
1965
|
-
state_dict[key
|
2284
|
+
state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
|
1966
2285
|
|
1967
2286
|
# Find invalid keys
|
1968
2287
|
transformer_state_dict = transformer.state_dict()
|
@@ -2017,6 +2336,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2017
2336
|
_pipeline=None,
|
2018
2337
|
low_cpu_mem_usage=False,
|
2019
2338
|
hotswap: bool = False,
|
2339
|
+
metadata=None,
|
2020
2340
|
):
|
2021
2341
|
"""
|
2022
2342
|
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
@@ -2040,31 +2360,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2040
2360
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
2041
2361
|
`default_{i}` where i is the total number of adapters being loaded.
|
2042
2362
|
low_cpu_mem_usage (`bool`, *optional*):
|
2043
|
-
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
2044
|
-
weights.
|
2045
|
-
hotswap
|
2046
|
-
|
2047
|
-
|
2048
|
-
|
2049
|
-
|
2050
|
-
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
2051
|
-
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
2052
|
-
|
2053
|
-
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
2054
|
-
to call an additional method before loading the adapter:
|
2055
|
-
|
2056
|
-
```py
|
2057
|
-
pipeline = ... # load diffusers pipeline
|
2058
|
-
max_rank = ... # the highest rank among all LoRAs that you want to load
|
2059
|
-
# call *before* compiling and loading the LoRA adapter
|
2060
|
-
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
2061
|
-
pipeline.load_lora_weights(file_name)
|
2062
|
-
# optionally compile the model now
|
2063
|
-
```
|
2064
|
-
|
2065
|
-
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
2066
|
-
limitations to this technique, which are documented here:
|
2067
|
-
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
2363
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
2364
|
+
weights.
|
2365
|
+
hotswap (`bool`, *optional*):
|
2366
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
2367
|
+
metadata (`dict`):
|
2368
|
+
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
|
2369
|
+
from the state dict.
|
2068
2370
|
"""
|
2069
2371
|
_load_lora_into_text_encoder(
|
2070
2372
|
state_dict=state_dict,
|
@@ -2074,6 +2376,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2074
2376
|
prefix=prefix,
|
2075
2377
|
text_encoder_name=cls.text_encoder_name,
|
2076
2378
|
adapter_name=adapter_name,
|
2379
|
+
metadata=metadata,
|
2077
2380
|
_pipeline=_pipeline,
|
2078
2381
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
2079
2382
|
hotswap=hotswap,
|
@@ -2090,6 +2393,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2090
2393
|
weight_name: str = None,
|
2091
2394
|
save_function: Callable = None,
|
2092
2395
|
safe_serialization: bool = True,
|
2396
|
+
transformer_lora_adapter_metadata=None,
|
2397
|
+
text_encoder_lora_adapter_metadata=None,
|
2093
2398
|
):
|
2094
2399
|
r"""
|
2095
2400
|
Save the LoRA parameters corresponding to the UNet and text encoder.
|
@@ -2112,8 +2417,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2112
2417
|
`DIFFUSERS_SAVE_MODE`.
|
2113
2418
|
safe_serialization (`bool`, *optional*, defaults to `True`):
|
2114
2419
|
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
2420
|
+
transformer_lora_adapter_metadata:
|
2421
|
+
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
2422
|
+
text_encoder_lora_adapter_metadata:
|
2423
|
+
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
|
2115
2424
|
"""
|
2116
2425
|
state_dict = {}
|
2426
|
+
lora_adapter_metadata = {}
|
2117
2427
|
|
2118
2428
|
if not (transformer_lora_layers or text_encoder_lora_layers):
|
2119
2429
|
raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
|
@@ -2124,6 +2434,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2124
2434
|
if text_encoder_lora_layers:
|
2125
2435
|
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
|
2126
2436
|
|
2437
|
+
if transformer_lora_adapter_metadata:
|
2438
|
+
lora_adapter_metadata.update(
|
2439
|
+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
2440
|
+
)
|
2441
|
+
|
2442
|
+
if text_encoder_lora_adapter_metadata:
|
2443
|
+
lora_adapter_metadata.update(
|
2444
|
+
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
|
2445
|
+
)
|
2446
|
+
|
2127
2447
|
# Save the model
|
2128
2448
|
cls.write_lora_layers(
|
2129
2449
|
state_dict=state_dict,
|
@@ -2132,6 +2452,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2132
2452
|
weight_name=weight_name,
|
2133
2453
|
save_function=save_function,
|
2134
2454
|
safe_serialization=safe_serialization,
|
2455
|
+
lora_adapter_metadata=lora_adapter_metadata,
|
2135
2456
|
)
|
2136
2457
|
|
2137
2458
|
def fuse_lora(
|
@@ -2293,7 +2614,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2293
2614
|
) -> bool:
|
2294
2615
|
"""
|
2295
2616
|
Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and
|
2296
|
-
generalizes things a bit so that any parameter that needs expansion receives appropriate
|
2617
|
+
generalizes things a bit so that any parameter that needs expansion receives appropriate treatment.
|
2297
2618
|
"""
|
2298
2619
|
state_dict = {}
|
2299
2620
|
if lora_state_dict is not None:
|
@@ -2305,7 +2626,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2305
2626
|
prefix = prefix or cls.transformer_name
|
2306
2627
|
for key in list(state_dict.keys()):
|
2307
2628
|
if key.split(".")[0] == prefix:
|
2308
|
-
state_dict[key
|
2629
|
+
state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
|
2309
2630
|
|
2310
2631
|
# Expand transformer parameter shapes if they don't match lora
|
2311
2632
|
has_param_with_shape_update = False
|
@@ -2423,14 +2744,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2423
2744
|
if unexpected_modules:
|
2424
2745
|
logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
|
2425
2746
|
|
2426
|
-
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
|
2427
2747
|
for k in lora_module_names:
|
2428
2748
|
if k in unexpected_modules:
|
2429
2749
|
continue
|
2430
2750
|
|
2431
2751
|
base_param_name = (
|
2432
2752
|
f"{k.replace(prefix, '')}.base_layer.weight"
|
2433
|
-
if
|
2753
|
+
if f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
|
2434
2754
|
else f"{k.replace(prefix, '')}.weight"
|
2435
2755
|
)
|
2436
2756
|
base_weight_param = transformer_state_dict[base_param_name]
|
@@ -2484,6 +2804,15 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
2484
2804
|
|
2485
2805
|
raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
|
2486
2806
|
|
2807
|
+
@staticmethod
|
2808
|
+
def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False):
|
2809
|
+
outputs = [state_dict]
|
2810
|
+
if return_alphas:
|
2811
|
+
outputs.append(alphas)
|
2812
|
+
if return_metadata:
|
2813
|
+
outputs.append(metadata)
|
2814
|
+
return tuple(outputs) if (return_alphas or return_metadata) else state_dict
|
2815
|
+
|
2487
2816
|
|
2488
2817
|
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
|
2489
2818
|
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
|
@@ -2500,6 +2829,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
|
2500
2829
|
network_alphas,
|
2501
2830
|
transformer,
|
2502
2831
|
adapter_name=None,
|
2832
|
+
metadata=None,
|
2503
2833
|
_pipeline=None,
|
2504
2834
|
low_cpu_mem_usage=False,
|
2505
2835
|
hotswap: bool = False,
|
@@ -2524,143 +2854,380 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
|
2524
2854
|
low_cpu_mem_usage (`bool`, *optional*):
|
2525
2855
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
2526
2856
|
weights.
|
2527
|
-
hotswap
|
2528
|
-
|
2529
|
-
|
2530
|
-
|
2531
|
-
|
2532
|
-
|
2533
|
-
|
2534
|
-
|
2535
|
-
|
2536
|
-
|
2537
|
-
|
2538
|
-
|
2539
|
-
|
2540
|
-
|
2541
|
-
|
2542
|
-
|
2543
|
-
|
2544
|
-
|
2545
|
-
|
2546
|
-
|
2547
|
-
|
2548
|
-
|
2549
|
-
|
2857
|
+
hotswap (`bool`, *optional*):
|
2858
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
2859
|
+
metadata (`dict`):
|
2860
|
+
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
|
2861
|
+
from the state dict.
|
2862
|
+
"""
|
2863
|
+
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
2864
|
+
raise ValueError(
|
2865
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
2866
|
+
)
|
2867
|
+
|
2868
|
+
# Load the layers corresponding to transformer.
|
2869
|
+
logger.info(f"Loading {cls.transformer_name}.")
|
2870
|
+
transformer.load_lora_adapter(
|
2871
|
+
state_dict,
|
2872
|
+
network_alphas=network_alphas,
|
2873
|
+
adapter_name=adapter_name,
|
2874
|
+
metadata=metadata,
|
2875
|
+
_pipeline=_pipeline,
|
2876
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
2877
|
+
hotswap=hotswap,
|
2878
|
+
)
|
2879
|
+
|
2880
|
+
@classmethod
|
2881
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
|
2882
|
+
def load_lora_into_text_encoder(
|
2883
|
+
cls,
|
2884
|
+
state_dict,
|
2885
|
+
network_alphas,
|
2886
|
+
text_encoder,
|
2887
|
+
prefix=None,
|
2888
|
+
lora_scale=1.0,
|
2889
|
+
adapter_name=None,
|
2890
|
+
_pipeline=None,
|
2891
|
+
low_cpu_mem_usage=False,
|
2892
|
+
hotswap: bool = False,
|
2893
|
+
metadata=None,
|
2894
|
+
):
|
2895
|
+
"""
|
2896
|
+
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
2897
|
+
|
2898
|
+
Parameters:
|
2899
|
+
state_dict (`dict`):
|
2900
|
+
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
2901
|
+
additional `text_encoder` to distinguish between unet lora layers.
|
2902
|
+
network_alphas (`Dict[str, float]`):
|
2903
|
+
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
2904
|
+
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
2905
|
+
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
2906
|
+
text_encoder (`CLIPTextModel`):
|
2907
|
+
The text encoder model to load the LoRA layers into.
|
2908
|
+
prefix (`str`):
|
2909
|
+
Expected prefix of the `text_encoder` in the `state_dict`.
|
2910
|
+
lora_scale (`float`):
|
2911
|
+
How much to scale the output of the lora linear layer before it is added with the output of the regular
|
2912
|
+
lora layer.
|
2913
|
+
adapter_name (`str`, *optional*):
|
2914
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
2915
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
2916
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
2917
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
2918
|
+
weights.
|
2919
|
+
hotswap (`bool`, *optional*):
|
2920
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
2921
|
+
metadata (`dict`):
|
2922
|
+
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
|
2923
|
+
from the state dict.
|
2924
|
+
"""
|
2925
|
+
_load_lora_into_text_encoder(
|
2926
|
+
state_dict=state_dict,
|
2927
|
+
network_alphas=network_alphas,
|
2928
|
+
lora_scale=lora_scale,
|
2929
|
+
text_encoder=text_encoder,
|
2930
|
+
prefix=prefix,
|
2931
|
+
text_encoder_name=cls.text_encoder_name,
|
2932
|
+
adapter_name=adapter_name,
|
2933
|
+
metadata=metadata,
|
2934
|
+
_pipeline=_pipeline,
|
2935
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
2936
|
+
hotswap=hotswap,
|
2937
|
+
)
|
2938
|
+
|
2939
|
+
@classmethod
|
2940
|
+
def save_lora_weights(
|
2941
|
+
cls,
|
2942
|
+
save_directory: Union[str, os.PathLike],
|
2943
|
+
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
|
2944
|
+
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
|
2945
|
+
is_main_process: bool = True,
|
2946
|
+
weight_name: str = None,
|
2947
|
+
save_function: Callable = None,
|
2948
|
+
safe_serialization: bool = True,
|
2949
|
+
):
|
2950
|
+
r"""
|
2951
|
+
Save the LoRA parameters corresponding to the UNet and text encoder.
|
2952
|
+
|
2953
|
+
Arguments:
|
2954
|
+
save_directory (`str` or `os.PathLike`):
|
2955
|
+
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
2956
|
+
unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
2957
|
+
State dict of the LoRA layers corresponding to the `unet`.
|
2958
|
+
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
2959
|
+
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
|
2960
|
+
encoder LoRA state dict because it comes from 🤗 Transformers.
|
2961
|
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
2962
|
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
2963
|
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
2964
|
+
process to avoid race conditions.
|
2965
|
+
save_function (`Callable`):
|
2966
|
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
2967
|
+
replace `torch.save` with another method. Can be configured with the environment variable
|
2968
|
+
`DIFFUSERS_SAVE_MODE`.
|
2969
|
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
2970
|
+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
2971
|
+
"""
|
2972
|
+
state_dict = {}
|
2973
|
+
|
2974
|
+
if not (transformer_lora_layers or text_encoder_lora_layers):
|
2975
|
+
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
2976
|
+
|
2977
|
+
if transformer_lora_layers:
|
2978
|
+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
2979
|
+
|
2980
|
+
if text_encoder_lora_layers:
|
2981
|
+
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
|
2982
|
+
|
2983
|
+
# Save the model
|
2984
|
+
cls.write_lora_layers(
|
2985
|
+
state_dict=state_dict,
|
2986
|
+
save_directory=save_directory,
|
2987
|
+
is_main_process=is_main_process,
|
2988
|
+
weight_name=weight_name,
|
2989
|
+
save_function=save_function,
|
2990
|
+
safe_serialization=safe_serialization,
|
2991
|
+
)
|
2992
|
+
|
2993
|
+
|
2994
|
+
class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
2995
|
+
r"""
|
2996
|
+
Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`].
|
2997
|
+
"""
|
2998
|
+
|
2999
|
+
_lora_loadable_modules = ["transformer"]
|
3000
|
+
transformer_name = TRANSFORMER_NAME
|
3001
|
+
|
3002
|
+
@classmethod
|
3003
|
+
@validate_hf_hub_args
|
3004
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
3005
|
+
def lora_state_dict(
|
3006
|
+
cls,
|
3007
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
3008
|
+
**kwargs,
|
3009
|
+
):
|
3010
|
+
r"""
|
3011
|
+
Return state dict for lora weights and the network alphas.
|
3012
|
+
|
3013
|
+
<Tip warning={true}>
|
3014
|
+
|
3015
|
+
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
3016
|
+
|
3017
|
+
This function is experimental and might change in the future.
|
3018
|
+
|
3019
|
+
</Tip>
|
3020
|
+
|
3021
|
+
Parameters:
|
3022
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
3023
|
+
Can be either:
|
3024
|
+
|
3025
|
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
3026
|
+
the Hub.
|
3027
|
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
3028
|
+
with [`ModelMixin.save_pretrained`].
|
3029
|
+
- A [torch state
|
3030
|
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
3031
|
+
|
3032
|
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
3033
|
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
3034
|
+
is not used.
|
3035
|
+
force_download (`bool`, *optional*, defaults to `False`):
|
3036
|
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
3037
|
+
cached versions if they exist.
|
3038
|
+
|
3039
|
+
proxies (`Dict[str, str]`, *optional*):
|
3040
|
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
3041
|
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
3042
|
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
3043
|
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
3044
|
+
won't be downloaded from the Hub.
|
3045
|
+
token (`str` or *bool*, *optional*):
|
3046
|
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
3047
|
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
3048
|
+
revision (`str`, *optional*, defaults to `"main"`):
|
3049
|
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
3050
|
+
allowed by Git.
|
3051
|
+
subfolder (`str`, *optional*, defaults to `""`):
|
3052
|
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
3053
|
+
return_lora_metadata (`bool`, *optional*, defaults to False):
|
3054
|
+
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
|
3055
|
+
|
3056
|
+
"""
|
3057
|
+
# Load the main state dict first which has the LoRA layers for either of
|
3058
|
+
# transformer and text encoder or both.
|
3059
|
+
cache_dir = kwargs.pop("cache_dir", None)
|
3060
|
+
force_download = kwargs.pop("force_download", False)
|
3061
|
+
proxies = kwargs.pop("proxies", None)
|
3062
|
+
local_files_only = kwargs.pop("local_files_only", None)
|
3063
|
+
token = kwargs.pop("token", None)
|
3064
|
+
revision = kwargs.pop("revision", None)
|
3065
|
+
subfolder = kwargs.pop("subfolder", None)
|
3066
|
+
weight_name = kwargs.pop("weight_name", None)
|
3067
|
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
3068
|
+
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
|
3069
|
+
|
3070
|
+
allow_pickle = False
|
3071
|
+
if use_safetensors is None:
|
3072
|
+
use_safetensors = True
|
3073
|
+
allow_pickle = True
|
3074
|
+
|
3075
|
+
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
3076
|
+
|
3077
|
+
state_dict, metadata = _fetch_state_dict(
|
3078
|
+
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
3079
|
+
weight_name=weight_name,
|
3080
|
+
use_safetensors=use_safetensors,
|
3081
|
+
local_files_only=local_files_only,
|
3082
|
+
cache_dir=cache_dir,
|
3083
|
+
force_download=force_download,
|
3084
|
+
proxies=proxies,
|
3085
|
+
token=token,
|
3086
|
+
revision=revision,
|
3087
|
+
subfolder=subfolder,
|
3088
|
+
user_agent=user_agent,
|
3089
|
+
allow_pickle=allow_pickle,
|
3090
|
+
)
|
3091
|
+
|
3092
|
+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
3093
|
+
if is_dora_scale_present:
|
3094
|
+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
3095
|
+
logger.warning(warn_msg)
|
3096
|
+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
3097
|
+
|
3098
|
+
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
3099
|
+
return out
|
3100
|
+
|
3101
|
+
def load_lora_weights(
|
3102
|
+
self,
|
3103
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
3104
|
+
adapter_name: Optional[str] = None,
|
3105
|
+
hotswap: bool = False,
|
3106
|
+
**kwargs,
|
3107
|
+
):
|
3108
|
+
"""
|
3109
|
+
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
3110
|
+
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
3111
|
+
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
3112
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
3113
|
+
dict is loaded into `self.transformer`.
|
3114
|
+
|
3115
|
+
Parameters:
|
3116
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
3117
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
3118
|
+
adapter_name (`str`, *optional*):
|
3119
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
3120
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
3121
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
3122
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3123
|
+
weights.
|
3124
|
+
hotswap (`bool`, *optional*):
|
3125
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
3126
|
+
kwargs (`dict`, *optional*):
|
3127
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
2550
3128
|
"""
|
2551
|
-
if
|
3129
|
+
if not USE_PEFT_BACKEND:
|
3130
|
+
raise ValueError("PEFT backend is required for this method.")
|
3131
|
+
|
3132
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
3133
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
2552
3134
|
raise ValueError(
|
2553
3135
|
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
2554
3136
|
)
|
2555
3137
|
|
2556
|
-
#
|
2557
|
-
|
2558
|
-
|
3138
|
+
# if a dict is passed, copy it instead of modifying it inplace
|
3139
|
+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
3140
|
+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
3141
|
+
|
3142
|
+
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
3143
|
+
kwargs["return_lora_metadata"] = True
|
3144
|
+
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
3145
|
+
|
3146
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
3147
|
+
if not is_correct_format:
|
3148
|
+
raise ValueError("Invalid LoRA checkpoint.")
|
3149
|
+
|
3150
|
+
self.load_lora_into_transformer(
|
2559
3151
|
state_dict,
|
2560
|
-
|
3152
|
+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
2561
3153
|
adapter_name=adapter_name,
|
2562
|
-
|
3154
|
+
metadata=metadata,
|
3155
|
+
_pipeline=self,
|
2563
3156
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
2564
3157
|
hotswap=hotswap,
|
2565
3158
|
)
|
2566
3159
|
|
2567
3160
|
@classmethod
|
2568
|
-
# Copied from diffusers.loaders.lora_pipeline.
|
2569
|
-
def
|
3161
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
|
3162
|
+
def load_lora_into_transformer(
|
2570
3163
|
cls,
|
2571
3164
|
state_dict,
|
2572
|
-
|
2573
|
-
text_encoder,
|
2574
|
-
prefix=None,
|
2575
|
-
lora_scale=1.0,
|
3165
|
+
transformer,
|
2576
3166
|
adapter_name=None,
|
2577
3167
|
_pipeline=None,
|
2578
3168
|
low_cpu_mem_usage=False,
|
2579
3169
|
hotswap: bool = False,
|
3170
|
+
metadata=None,
|
2580
3171
|
):
|
2581
3172
|
"""
|
2582
|
-
This will load the LoRA layers specified in `state_dict` into `
|
3173
|
+
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
2583
3174
|
|
2584
3175
|
Parameters:
|
2585
3176
|
state_dict (`dict`):
|
2586
|
-
A standard state dict containing the lora layer parameters. The
|
2587
|
-
additional `
|
2588
|
-
|
2589
|
-
|
2590
|
-
|
2591
|
-
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
2592
|
-
text_encoder (`CLIPTextModel`):
|
2593
|
-
The text encoder model to load the LoRA layers into.
|
2594
|
-
prefix (`str`):
|
2595
|
-
Expected prefix of the `text_encoder` in the `state_dict`.
|
2596
|
-
lora_scale (`float`):
|
2597
|
-
How much to scale the output of the lora linear layer before it is added with the output of the regular
|
2598
|
-
lora layer.
|
3177
|
+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
3178
|
+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
3179
|
+
encoder lora layers.
|
3180
|
+
transformer (`CogVideoXTransformer3DModel`):
|
3181
|
+
The Transformer model to load the LoRA layers into.
|
2599
3182
|
adapter_name (`str`, *optional*):
|
2600
3183
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
2601
3184
|
`default_{i}` where i is the total number of adapters being loaded.
|
2602
3185
|
low_cpu_mem_usage (`bool`, *optional*):
|
2603
3186
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
2604
3187
|
weights.
|
2605
|
-
hotswap
|
2606
|
-
|
2607
|
-
|
2608
|
-
|
2609
|
-
|
2610
|
-
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
2611
|
-
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
2612
|
-
|
2613
|
-
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
2614
|
-
to call an additional method before loading the adapter:
|
2615
|
-
|
2616
|
-
```py
|
2617
|
-
pipeline = ... # load diffusers pipeline
|
2618
|
-
max_rank = ... # the highest rank among all LoRAs that you want to load
|
2619
|
-
# call *before* compiling and loading the LoRA adapter
|
2620
|
-
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
2621
|
-
pipeline.load_lora_weights(file_name)
|
2622
|
-
# optionally compile the model now
|
2623
|
-
```
|
2624
|
-
|
2625
|
-
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
2626
|
-
limitations to this technique, which are documented here:
|
2627
|
-
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
3188
|
+
hotswap (`bool`, *optional*):
|
3189
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
3190
|
+
metadata (`dict`):
|
3191
|
+
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
|
3192
|
+
from the state dict.
|
2628
3193
|
"""
|
2629
|
-
|
2630
|
-
|
2631
|
-
|
2632
|
-
|
2633
|
-
|
2634
|
-
|
2635
|
-
|
3194
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3195
|
+
raise ValueError(
|
3196
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
3197
|
+
)
|
3198
|
+
|
3199
|
+
# Load the layers corresponding to transformer.
|
3200
|
+
logger.info(f"Loading {cls.transformer_name}.")
|
3201
|
+
transformer.load_lora_adapter(
|
3202
|
+
state_dict,
|
3203
|
+
network_alphas=None,
|
2636
3204
|
adapter_name=adapter_name,
|
3205
|
+
metadata=metadata,
|
2637
3206
|
_pipeline=_pipeline,
|
2638
3207
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
2639
3208
|
hotswap=hotswap,
|
2640
3209
|
)
|
2641
3210
|
|
2642
3211
|
@classmethod
|
3212
|
+
# Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
|
2643
3213
|
def save_lora_weights(
|
2644
3214
|
cls,
|
2645
3215
|
save_directory: Union[str, os.PathLike],
|
2646
|
-
|
2647
|
-
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
|
3216
|
+
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
2648
3217
|
is_main_process: bool = True,
|
2649
3218
|
weight_name: str = None,
|
2650
3219
|
save_function: Callable = None,
|
2651
3220
|
safe_serialization: bool = True,
|
3221
|
+
transformer_lora_adapter_metadata: Optional[dict] = None,
|
2652
3222
|
):
|
2653
3223
|
r"""
|
2654
|
-
Save the LoRA parameters corresponding to the
|
3224
|
+
Save the LoRA parameters corresponding to the transformer.
|
2655
3225
|
|
2656
3226
|
Arguments:
|
2657
3227
|
save_directory (`str` or `os.PathLike`):
|
2658
3228
|
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
2659
|
-
|
2660
|
-
State dict of the LoRA layers corresponding to the `
|
2661
|
-
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
2662
|
-
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
|
2663
|
-
encoder LoRA state dict because it comes from 🤗 Transformers.
|
3229
|
+
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
3230
|
+
State dict of the LoRA layers corresponding to the `transformer`.
|
2664
3231
|
is_main_process (`bool`, *optional*, defaults to `True`):
|
2665
3232
|
Whether the process calling this is the main process or not. Useful during distributed training and you
|
2666
3233
|
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
@@ -2671,17 +3238,21 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
|
2671
3238
|
`DIFFUSERS_SAVE_MODE`.
|
2672
3239
|
safe_serialization (`bool`, *optional*, defaults to `True`):
|
2673
3240
|
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
3241
|
+
transformer_lora_adapter_metadata:
|
3242
|
+
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
2674
3243
|
"""
|
2675
3244
|
state_dict = {}
|
3245
|
+
lora_adapter_metadata = {}
|
2676
3246
|
|
2677
|
-
if not
|
2678
|
-
raise ValueError("You must pass
|
3247
|
+
if not transformer_lora_layers:
|
3248
|
+
raise ValueError("You must pass `transformer_lora_layers`.")
|
2679
3249
|
|
2680
|
-
|
2681
|
-
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
3250
|
+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
2682
3251
|
|
2683
|
-
if
|
2684
|
-
|
3252
|
+
if transformer_lora_adapter_metadata is not None:
|
3253
|
+
lora_adapter_metadata.update(
|
3254
|
+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
3255
|
+
)
|
2685
3256
|
|
2686
3257
|
# Save the model
|
2687
3258
|
cls.write_lora_layers(
|
@@ -2691,12 +3262,77 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
|
2691
3262
|
weight_name=weight_name,
|
2692
3263
|
save_function=save_function,
|
2693
3264
|
safe_serialization=safe_serialization,
|
3265
|
+
lora_adapter_metadata=lora_adapter_metadata,
|
2694
3266
|
)
|
2695
3267
|
|
3268
|
+
def fuse_lora(
|
3269
|
+
self,
|
3270
|
+
components: List[str] = ["transformer"],
|
3271
|
+
lora_scale: float = 1.0,
|
3272
|
+
safe_fusing: bool = False,
|
3273
|
+
adapter_names: Optional[List[str]] = None,
|
3274
|
+
**kwargs,
|
3275
|
+
):
|
3276
|
+
r"""
|
3277
|
+
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
3278
|
+
|
3279
|
+
<Tip warning={true}>
|
2696
3280
|
|
2697
|
-
|
3281
|
+
This is an experimental API.
|
3282
|
+
|
3283
|
+
</Tip>
|
3284
|
+
|
3285
|
+
Args:
|
3286
|
+
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
3287
|
+
lora_scale (`float`, defaults to 1.0):
|
3288
|
+
Controls how much to influence the outputs with the LoRA parameters.
|
3289
|
+
safe_fusing (`bool`, defaults to `False`):
|
3290
|
+
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
3291
|
+
adapter_names (`List[str]`, *optional*):
|
3292
|
+
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
3293
|
+
|
3294
|
+
Example:
|
3295
|
+
|
3296
|
+
```py
|
3297
|
+
from diffusers import DiffusionPipeline
|
3298
|
+
import torch
|
3299
|
+
|
3300
|
+
pipeline = DiffusionPipeline.from_pretrained(
|
3301
|
+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
3302
|
+
).to("cuda")
|
3303
|
+
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
3304
|
+
pipeline.fuse_lora(lora_scale=0.7)
|
3305
|
+
```
|
3306
|
+
"""
|
3307
|
+
super().fuse_lora(
|
3308
|
+
components=components,
|
3309
|
+
lora_scale=lora_scale,
|
3310
|
+
safe_fusing=safe_fusing,
|
3311
|
+
adapter_names=adapter_names,
|
3312
|
+
**kwargs,
|
3313
|
+
)
|
3314
|
+
|
3315
|
+
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
3316
|
+
r"""
|
3317
|
+
Reverses the effect of
|
3318
|
+
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
3319
|
+
|
3320
|
+
<Tip warning={true}>
|
3321
|
+
|
3322
|
+
This is an experimental API.
|
3323
|
+
|
3324
|
+
</Tip>
|
3325
|
+
|
3326
|
+
Args:
|
3327
|
+
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
3328
|
+
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
3329
|
+
"""
|
3330
|
+
super().unfuse_lora(components=components, **kwargs)
|
3331
|
+
|
3332
|
+
|
3333
|
+
class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
2698
3334
|
r"""
|
2699
|
-
Load LoRA layers into [`
|
3335
|
+
Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].
|
2700
3336
|
"""
|
2701
3337
|
|
2702
3338
|
_lora_loadable_modules = ["transformer"]
|
@@ -2753,6 +3389,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2753
3389
|
allowed by Git.
|
2754
3390
|
subfolder (`str`, *optional*, defaults to `""`):
|
2755
3391
|
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
3392
|
+
return_lora_metadata (`bool`, *optional*, defaults to False):
|
3393
|
+
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
|
2756
3394
|
|
2757
3395
|
"""
|
2758
3396
|
# Load the main state dict first which has the LoRA layers for either of
|
@@ -2766,18 +3404,16 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2766
3404
|
subfolder = kwargs.pop("subfolder", None)
|
2767
3405
|
weight_name = kwargs.pop("weight_name", None)
|
2768
3406
|
use_safetensors = kwargs.pop("use_safetensors", None)
|
3407
|
+
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
|
2769
3408
|
|
2770
3409
|
allow_pickle = False
|
2771
3410
|
if use_safetensors is None:
|
2772
3411
|
use_safetensors = True
|
2773
3412
|
allow_pickle = True
|
2774
3413
|
|
2775
|
-
user_agent = {
|
2776
|
-
"file_type": "attn_procs_weights",
|
2777
|
-
"framework": "pytorch",
|
2778
|
-
}
|
3414
|
+
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
2779
3415
|
|
2780
|
-
state_dict = _fetch_state_dict(
|
3416
|
+
state_dict, metadata = _fetch_state_dict(
|
2781
3417
|
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
2782
3418
|
weight_name=weight_name,
|
2783
3419
|
use_safetensors=use_safetensors,
|
@@ -2798,10 +3434,16 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2798
3434
|
logger.warning(warn_msg)
|
2799
3435
|
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
2800
3436
|
|
2801
|
-
|
3437
|
+
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
3438
|
+
return out
|
2802
3439
|
|
3440
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
2803
3441
|
def load_lora_weights(
|
2804
|
-
self,
|
3442
|
+
self,
|
3443
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
3444
|
+
adapter_name: Optional[str] = None,
|
3445
|
+
hotswap: bool = False,
|
3446
|
+
**kwargs,
|
2805
3447
|
):
|
2806
3448
|
"""
|
2807
3449
|
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
@@ -2819,6 +3461,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2819
3461
|
low_cpu_mem_usage (`bool`, *optional*):
|
2820
3462
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
2821
3463
|
weights.
|
3464
|
+
hotswap (`bool`, *optional*):
|
3465
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
2822
3466
|
kwargs (`dict`, *optional*):
|
2823
3467
|
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
2824
3468
|
"""
|
@@ -2836,7 +3480,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2836
3480
|
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
2837
3481
|
|
2838
3482
|
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
2839
|
-
|
3483
|
+
kwargs["return_lora_metadata"] = True
|
3484
|
+
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
2840
3485
|
|
2841
3486
|
is_correct_format = all("lora" in key for key in state_dict.keys())
|
2842
3487
|
if not is_correct_format:
|
@@ -2846,54 +3491,45 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2846
3491
|
state_dict,
|
2847
3492
|
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
2848
3493
|
adapter_name=adapter_name,
|
3494
|
+
metadata=metadata,
|
2849
3495
|
_pipeline=self,
|
2850
3496
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
3497
|
+
hotswap=hotswap,
|
2851
3498
|
)
|
2852
3499
|
|
2853
3500
|
@classmethod
|
2854
|
-
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->
|
3501
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
|
2855
3502
|
def load_lora_into_transformer(
|
2856
|
-
cls,
|
3503
|
+
cls,
|
3504
|
+
state_dict,
|
3505
|
+
transformer,
|
3506
|
+
adapter_name=None,
|
3507
|
+
_pipeline=None,
|
3508
|
+
low_cpu_mem_usage=False,
|
3509
|
+
hotswap: bool = False,
|
3510
|
+
metadata=None,
|
2857
3511
|
):
|
2858
3512
|
"""
|
2859
3513
|
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
2860
3514
|
|
2861
|
-
Parameters:
|
2862
|
-
state_dict (`dict`):
|
2863
|
-
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
2864
|
-
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
2865
|
-
encoder lora layers.
|
2866
|
-
transformer (`
|
2867
|
-
The Transformer model to load the LoRA layers into.
|
2868
|
-
adapter_name (`str`, *optional*):
|
2869
|
-
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
2870
|
-
`default_{i}` where i is the total number of adapters being loaded.
|
2871
|
-
low_cpu_mem_usage (`bool`, *optional*):
|
2872
|
-
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
2873
|
-
weights.
|
2874
|
-
hotswap
|
2875
|
-
|
2876
|
-
|
2877
|
-
|
2878
|
-
|
2879
|
-
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
2880
|
-
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
2881
|
-
|
2882
|
-
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
2883
|
-
to call an additional method before loading the adapter:
|
2884
|
-
|
2885
|
-
```py
|
2886
|
-
pipeline = ... # load diffusers pipeline
|
2887
|
-
max_rank = ... # the highest rank among all LoRAs that you want to load
|
2888
|
-
# call *before* compiling and loading the LoRA adapter
|
2889
|
-
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
2890
|
-
pipeline.load_lora_weights(file_name)
|
2891
|
-
# optionally compile the model now
|
2892
|
-
```
|
2893
|
-
|
2894
|
-
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
2895
|
-
limitations to this technique, which are documented here:
|
2896
|
-
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
3515
|
+
Parameters:
|
3516
|
+
state_dict (`dict`):
|
3517
|
+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
3518
|
+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
3519
|
+
encoder lora layers.
|
3520
|
+
transformer (`MochiTransformer3DModel`):
|
3521
|
+
The Transformer model to load the LoRA layers into.
|
3522
|
+
adapter_name (`str`, *optional*):
|
3523
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
3524
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
3525
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
3526
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3527
|
+
weights.
|
3528
|
+
hotswap (`bool`, *optional*):
|
3529
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
3530
|
+
metadata (`dict`):
|
3531
|
+
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
|
3532
|
+
from the state dict.
|
2897
3533
|
"""
|
2898
3534
|
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
2899
3535
|
raise ValueError(
|
@@ -2906,13 +3542,14 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2906
3542
|
state_dict,
|
2907
3543
|
network_alphas=None,
|
2908
3544
|
adapter_name=adapter_name,
|
3545
|
+
metadata=metadata,
|
2909
3546
|
_pipeline=_pipeline,
|
2910
3547
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
2911
3548
|
hotswap=hotswap,
|
2912
3549
|
)
|
2913
3550
|
|
2914
3551
|
@classmethod
|
2915
|
-
#
|
3552
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
2916
3553
|
def save_lora_weights(
|
2917
3554
|
cls,
|
2918
3555
|
save_directory: Union[str, os.PathLike],
|
@@ -2921,9 +3558,10 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2921
3558
|
weight_name: str = None,
|
2922
3559
|
save_function: Callable = None,
|
2923
3560
|
safe_serialization: bool = True,
|
3561
|
+
transformer_lora_adapter_metadata: Optional[dict] = None,
|
2924
3562
|
):
|
2925
3563
|
r"""
|
2926
|
-
Save the LoRA parameters corresponding to the
|
3564
|
+
Save the LoRA parameters corresponding to the transformer.
|
2927
3565
|
|
2928
3566
|
Arguments:
|
2929
3567
|
save_directory (`str` or `os.PathLike`):
|
@@ -2940,14 +3578,21 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2940
3578
|
`DIFFUSERS_SAVE_MODE`.
|
2941
3579
|
safe_serialization (`bool`, *optional*, defaults to `True`):
|
2942
3580
|
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
3581
|
+
transformer_lora_adapter_metadata:
|
3582
|
+
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
2943
3583
|
"""
|
2944
3584
|
state_dict = {}
|
3585
|
+
lora_adapter_metadata = {}
|
2945
3586
|
|
2946
3587
|
if not transformer_lora_layers:
|
2947
3588
|
raise ValueError("You must pass `transformer_lora_layers`.")
|
2948
3589
|
|
2949
|
-
|
2950
|
-
|
3590
|
+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
3591
|
+
|
3592
|
+
if transformer_lora_adapter_metadata is not None:
|
3593
|
+
lora_adapter_metadata.update(
|
3594
|
+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
3595
|
+
)
|
2951
3596
|
|
2952
3597
|
# Save the model
|
2953
3598
|
cls.write_lora_layers(
|
@@ -2957,8 +3602,10 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2957
3602
|
weight_name=weight_name,
|
2958
3603
|
save_function=save_function,
|
2959
3604
|
safe_serialization=safe_serialization,
|
3605
|
+
lora_adapter_metadata=lora_adapter_metadata,
|
2960
3606
|
)
|
2961
3607
|
|
3608
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
2962
3609
|
def fuse_lora(
|
2963
3610
|
self,
|
2964
3611
|
components: List[str] = ["transformer"],
|
@@ -3006,6 +3653,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
3006
3653
|
**kwargs,
|
3007
3654
|
)
|
3008
3655
|
|
3656
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
3009
3657
|
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
3010
3658
|
r"""
|
3011
3659
|
Reverses the effect of
|
@@ -3024,9 +3672,9 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
3024
3672
|
super().unfuse_lora(components=components, **kwargs)
|
3025
3673
|
|
3026
3674
|
|
3027
|
-
class
|
3675
|
+
class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
3028
3676
|
r"""
|
3029
|
-
Load LoRA layers into [`
|
3677
|
+
Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`].
|
3030
3678
|
"""
|
3031
3679
|
|
3032
3680
|
_lora_loadable_modules = ["transformer"]
|
@@ -3034,7 +3682,6 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3034
3682
|
|
3035
3683
|
@classmethod
|
3036
3684
|
@validate_hf_hub_args
|
3037
|
-
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
3038
3685
|
def lora_state_dict(
|
3039
3686
|
cls,
|
3040
3687
|
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
@@ -3083,7 +3730,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3083
3730
|
allowed by Git.
|
3084
3731
|
subfolder (`str`, *optional*, defaults to `""`):
|
3085
3732
|
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
3086
|
-
|
3733
|
+
return_lora_metadata (`bool`, *optional*, defaults to False):
|
3734
|
+
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
|
3087
3735
|
"""
|
3088
3736
|
# Load the main state dict first which has the LoRA layers for either of
|
3089
3737
|
# transformer and text encoder or both.
|
@@ -3096,18 +3744,16 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3096
3744
|
subfolder = kwargs.pop("subfolder", None)
|
3097
3745
|
weight_name = kwargs.pop("weight_name", None)
|
3098
3746
|
use_safetensors = kwargs.pop("use_safetensors", None)
|
3747
|
+
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
|
3099
3748
|
|
3100
3749
|
allow_pickle = False
|
3101
3750
|
if use_safetensors is None:
|
3102
3751
|
use_safetensors = True
|
3103
3752
|
allow_pickle = True
|
3104
3753
|
|
3105
|
-
user_agent = {
|
3106
|
-
"file_type": "attn_procs_weights",
|
3107
|
-
"framework": "pytorch",
|
3108
|
-
}
|
3754
|
+
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
3109
3755
|
|
3110
|
-
state_dict = _fetch_state_dict(
|
3756
|
+
state_dict, metadata = _fetch_state_dict(
|
3111
3757
|
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
3112
3758
|
weight_name=weight_name,
|
3113
3759
|
use_safetensors=use_safetensors,
|
@@ -3128,11 +3774,20 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3128
3774
|
logger.warning(warn_msg)
|
3129
3775
|
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
3130
3776
|
|
3131
|
-
|
3777
|
+
is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict)
|
3778
|
+
if is_non_diffusers_format:
|
3779
|
+
state_dict = _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict)
|
3780
|
+
|
3781
|
+
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
3782
|
+
return out
|
3132
3783
|
|
3133
3784
|
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
3134
3785
|
def load_lora_weights(
|
3135
|
-
self,
|
3786
|
+
self,
|
3787
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
3788
|
+
adapter_name: Optional[str] = None,
|
3789
|
+
hotswap: bool = False,
|
3790
|
+
**kwargs,
|
3136
3791
|
):
|
3137
3792
|
"""
|
3138
3793
|
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
@@ -3150,6 +3805,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3150
3805
|
low_cpu_mem_usage (`bool`, *optional*):
|
3151
3806
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3152
3807
|
weights.
|
3808
|
+
hotswap (`bool`, *optional*):
|
3809
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
3153
3810
|
kwargs (`dict`, *optional*):
|
3154
3811
|
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
3155
3812
|
"""
|
@@ -3167,7 +3824,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3167
3824
|
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
3168
3825
|
|
3169
3826
|
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
3170
|
-
|
3827
|
+
kwargs["return_lora_metadata"] = True
|
3828
|
+
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
3171
3829
|
|
3172
3830
|
is_correct_format = all("lora" in key for key in state_dict.keys())
|
3173
3831
|
if not is_correct_format:
|
@@ -3177,14 +3835,23 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3177
3835
|
state_dict,
|
3178
3836
|
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
3179
3837
|
adapter_name=adapter_name,
|
3838
|
+
metadata=metadata,
|
3180
3839
|
_pipeline=self,
|
3181
3840
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
3841
|
+
hotswap=hotswap,
|
3182
3842
|
)
|
3183
3843
|
|
3184
3844
|
@classmethod
|
3185
|
-
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->
|
3845
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
|
3186
3846
|
def load_lora_into_transformer(
|
3187
|
-
cls,
|
3847
|
+
cls,
|
3848
|
+
state_dict,
|
3849
|
+
transformer,
|
3850
|
+
adapter_name=None,
|
3851
|
+
_pipeline=None,
|
3852
|
+
low_cpu_mem_usage=False,
|
3853
|
+
hotswap: bool = False,
|
3854
|
+
metadata=None,
|
3188
3855
|
):
|
3189
3856
|
"""
|
3190
3857
|
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
@@ -3194,7 +3861,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3194
3861
|
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
3195
3862
|
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
3196
3863
|
encoder lora layers.
|
3197
|
-
transformer (`
|
3864
|
+
transformer (`LTXVideoTransformer3DModel`):
|
3198
3865
|
The Transformer model to load the LoRA layers into.
|
3199
3866
|
adapter_name (`str`, *optional*):
|
3200
3867
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
@@ -3202,29 +3869,11 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3202
3869
|
low_cpu_mem_usage (`bool`, *optional*):
|
3203
3870
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3204
3871
|
weights.
|
3205
|
-
hotswap
|
3206
|
-
|
3207
|
-
|
3208
|
-
|
3209
|
-
|
3210
|
-
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
3211
|
-
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
3212
|
-
|
3213
|
-
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
3214
|
-
to call an additional method before loading the adapter:
|
3215
|
-
|
3216
|
-
```py
|
3217
|
-
pipeline = ... # load diffusers pipeline
|
3218
|
-
max_rank = ... # the highest rank among all LoRAs that you want to load
|
3219
|
-
# call *before* compiling and loading the LoRA adapter
|
3220
|
-
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
3221
|
-
pipeline.load_lora_weights(file_name)
|
3222
|
-
# optionally compile the model now
|
3223
|
-
```
|
3224
|
-
|
3225
|
-
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
3226
|
-
limitations to this technique, which are documented here:
|
3227
|
-
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
3872
|
+
hotswap (`bool`, *optional*):
|
3873
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
3874
|
+
metadata (`dict`):
|
3875
|
+
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
|
3876
|
+
from the state dict.
|
3228
3877
|
"""
|
3229
3878
|
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3230
3879
|
raise ValueError(
|
@@ -3237,6 +3886,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3237
3886
|
state_dict,
|
3238
3887
|
network_alphas=None,
|
3239
3888
|
adapter_name=adapter_name,
|
3889
|
+
metadata=metadata,
|
3240
3890
|
_pipeline=_pipeline,
|
3241
3891
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
3242
3892
|
hotswap=hotswap,
|
@@ -3252,9 +3902,10 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3252
3902
|
weight_name: str = None,
|
3253
3903
|
save_function: Callable = None,
|
3254
3904
|
safe_serialization: bool = True,
|
3905
|
+
transformer_lora_adapter_metadata: Optional[dict] = None,
|
3255
3906
|
):
|
3256
3907
|
r"""
|
3257
|
-
Save the LoRA parameters corresponding to the
|
3908
|
+
Save the LoRA parameters corresponding to the transformer.
|
3258
3909
|
|
3259
3910
|
Arguments:
|
3260
3911
|
save_directory (`str` or `os.PathLike`):
|
@@ -3271,14 +3922,21 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3271
3922
|
`DIFFUSERS_SAVE_MODE`.
|
3272
3923
|
safe_serialization (`bool`, *optional*, defaults to `True`):
|
3273
3924
|
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
3925
|
+
transformer_lora_adapter_metadata:
|
3926
|
+
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
3274
3927
|
"""
|
3275
3928
|
state_dict = {}
|
3929
|
+
lora_adapter_metadata = {}
|
3276
3930
|
|
3277
3931
|
if not transformer_lora_layers:
|
3278
3932
|
raise ValueError("You must pass `transformer_lora_layers`.")
|
3279
3933
|
|
3280
|
-
|
3281
|
-
|
3934
|
+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
3935
|
+
|
3936
|
+
if transformer_lora_adapter_metadata is not None:
|
3937
|
+
lora_adapter_metadata.update(
|
3938
|
+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
3939
|
+
)
|
3282
3940
|
|
3283
3941
|
# Save the model
|
3284
3942
|
cls.write_lora_layers(
|
@@ -3288,6 +3946,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3288
3946
|
weight_name=weight_name,
|
3289
3947
|
save_function=save_function,
|
3290
3948
|
safe_serialization=safe_serialization,
|
3949
|
+
lora_adapter_metadata=lora_adapter_metadata,
|
3291
3950
|
)
|
3292
3951
|
|
3293
3952
|
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
@@ -3357,9 +4016,9 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|
3357
4016
|
super().unfuse_lora(components=components, **kwargs)
|
3358
4017
|
|
3359
4018
|
|
3360
|
-
class
|
4019
|
+
class SanaLoraLoaderMixin(LoraBaseMixin):
|
3361
4020
|
r"""
|
3362
|
-
Load LoRA layers into [`
|
4021
|
+
Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
|
3363
4022
|
"""
|
3364
4023
|
|
3365
4024
|
_lora_loadable_modules = ["transformer"]
|
@@ -3367,7 +4026,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3367
4026
|
|
3368
4027
|
@classmethod
|
3369
4028
|
@validate_hf_hub_args
|
3370
|
-
# Copied from diffusers.loaders.lora_pipeline.
|
4029
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
3371
4030
|
def lora_state_dict(
|
3372
4031
|
cls,
|
3373
4032
|
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
@@ -3416,6 +4075,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3416
4075
|
allowed by Git.
|
3417
4076
|
subfolder (`str`, *optional*, defaults to `""`):
|
3418
4077
|
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
4078
|
+
return_lora_metadata (`bool`, *optional*, defaults to False):
|
4079
|
+
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
|
3419
4080
|
|
3420
4081
|
"""
|
3421
4082
|
# Load the main state dict first which has the LoRA layers for either of
|
@@ -3429,18 +4090,16 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3429
4090
|
subfolder = kwargs.pop("subfolder", None)
|
3430
4091
|
weight_name = kwargs.pop("weight_name", None)
|
3431
4092
|
use_safetensors = kwargs.pop("use_safetensors", None)
|
4093
|
+
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
|
3432
4094
|
|
3433
4095
|
allow_pickle = False
|
3434
4096
|
if use_safetensors is None:
|
3435
4097
|
use_safetensors = True
|
3436
4098
|
allow_pickle = True
|
3437
4099
|
|
3438
|
-
user_agent = {
|
3439
|
-
"file_type": "attn_procs_weights",
|
3440
|
-
"framework": "pytorch",
|
3441
|
-
}
|
4100
|
+
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
3442
4101
|
|
3443
|
-
state_dict = _fetch_state_dict(
|
4102
|
+
state_dict, metadata = _fetch_state_dict(
|
3444
4103
|
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
3445
4104
|
weight_name=weight_name,
|
3446
4105
|
use_safetensors=use_safetensors,
|
@@ -3461,11 +4120,16 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3461
4120
|
logger.warning(warn_msg)
|
3462
4121
|
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
3463
4122
|
|
3464
|
-
|
4123
|
+
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
4124
|
+
return out
|
3465
4125
|
|
3466
4126
|
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
3467
4127
|
def load_lora_weights(
|
3468
|
-
self,
|
4128
|
+
self,
|
4129
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
4130
|
+
adapter_name: Optional[str] = None,
|
4131
|
+
hotswap: bool = False,
|
4132
|
+
**kwargs,
|
3469
4133
|
):
|
3470
4134
|
"""
|
3471
4135
|
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
@@ -3483,6 +4147,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3483
4147
|
low_cpu_mem_usage (`bool`, *optional*):
|
3484
4148
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3485
4149
|
weights.
|
4150
|
+
hotswap (`bool`, *optional*):
|
4151
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
3486
4152
|
kwargs (`dict`, *optional*):
|
3487
4153
|
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
3488
4154
|
"""
|
@@ -3500,7 +4166,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3500
4166
|
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
3501
4167
|
|
3502
4168
|
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
3503
|
-
|
4169
|
+
kwargs["return_lora_metadata"] = True
|
4170
|
+
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
3504
4171
|
|
3505
4172
|
is_correct_format = all("lora" in key for key in state_dict.keys())
|
3506
4173
|
if not is_correct_format:
|
@@ -3510,14 +4177,23 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3510
4177
|
state_dict,
|
3511
4178
|
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
3512
4179
|
adapter_name=adapter_name,
|
4180
|
+
metadata=metadata,
|
3513
4181
|
_pipeline=self,
|
3514
4182
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
4183
|
+
hotswap=hotswap,
|
3515
4184
|
)
|
3516
4185
|
|
3517
4186
|
@classmethod
|
3518
|
-
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->
|
4187
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
|
3519
4188
|
def load_lora_into_transformer(
|
3520
|
-
cls,
|
4189
|
+
cls,
|
4190
|
+
state_dict,
|
4191
|
+
transformer,
|
4192
|
+
adapter_name=None,
|
4193
|
+
_pipeline=None,
|
4194
|
+
low_cpu_mem_usage=False,
|
4195
|
+
hotswap: bool = False,
|
4196
|
+
metadata=None,
|
3521
4197
|
):
|
3522
4198
|
"""
|
3523
4199
|
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
@@ -3527,7 +4203,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3527
4203
|
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
3528
4204
|
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
3529
4205
|
encoder lora layers.
|
3530
|
-
transformer (`
|
4206
|
+
transformer (`SanaTransformer2DModel`):
|
3531
4207
|
The Transformer model to load the LoRA layers into.
|
3532
4208
|
adapter_name (`str`, *optional*):
|
3533
4209
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
@@ -3535,29 +4211,11 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3535
4211
|
low_cpu_mem_usage (`bool`, *optional*):
|
3536
4212
|
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3537
4213
|
weights.
|
3538
|
-
hotswap
|
3539
|
-
|
3540
|
-
|
3541
|
-
|
3542
|
-
|
3543
|
-
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
3544
|
-
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
3545
|
-
|
3546
|
-
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
3547
|
-
to call an additional method before loading the adapter:
|
3548
|
-
|
3549
|
-
```py
|
3550
|
-
pipeline = ... # load diffusers pipeline
|
3551
|
-
max_rank = ... # the highest rank among all LoRAs that you want to load
|
3552
|
-
# call *before* compiling and loading the LoRA adapter
|
3553
|
-
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
3554
|
-
pipeline.load_lora_weights(file_name)
|
3555
|
-
# optionally compile the model now
|
3556
|
-
```
|
3557
|
-
|
3558
|
-
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
3559
|
-
limitations to this technique, which are documented here:
|
3560
|
-
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
4214
|
+
hotswap (`bool`, *optional*):
|
4215
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
4216
|
+
metadata (`dict`):
|
4217
|
+
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
|
4218
|
+
from the state dict.
|
3561
4219
|
"""
|
3562
4220
|
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3563
4221
|
raise ValueError(
|
@@ -3570,6 +4228,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3570
4228
|
state_dict,
|
3571
4229
|
network_alphas=None,
|
3572
4230
|
adapter_name=adapter_name,
|
4231
|
+
metadata=metadata,
|
3573
4232
|
_pipeline=_pipeline,
|
3574
4233
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
3575
4234
|
hotswap=hotswap,
|
@@ -3585,9 +4244,10 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3585
4244
|
weight_name: str = None,
|
3586
4245
|
save_function: Callable = None,
|
3587
4246
|
safe_serialization: bool = True,
|
4247
|
+
transformer_lora_adapter_metadata: Optional[dict] = None,
|
3588
4248
|
):
|
3589
4249
|
r"""
|
3590
|
-
Save the LoRA parameters corresponding to the
|
4250
|
+
Save the LoRA parameters corresponding to the transformer.
|
3591
4251
|
|
3592
4252
|
Arguments:
|
3593
4253
|
save_directory (`str` or `os.PathLike`):
|
@@ -3604,14 +4264,21 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3604
4264
|
`DIFFUSERS_SAVE_MODE`.
|
3605
4265
|
safe_serialization (`bool`, *optional*, defaults to `True`):
|
3606
4266
|
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
4267
|
+
transformer_lora_adapter_metadata:
|
4268
|
+
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
3607
4269
|
"""
|
3608
4270
|
state_dict = {}
|
4271
|
+
lora_adapter_metadata = {}
|
3609
4272
|
|
3610
4273
|
if not transformer_lora_layers:
|
3611
4274
|
raise ValueError("You must pass `transformer_lora_layers`.")
|
3612
4275
|
|
3613
|
-
|
3614
|
-
|
4276
|
+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
4277
|
+
|
4278
|
+
if transformer_lora_adapter_metadata is not None:
|
4279
|
+
lora_adapter_metadata.update(
|
4280
|
+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
4281
|
+
)
|
3615
4282
|
|
3616
4283
|
# Save the model
|
3617
4284
|
cls.write_lora_layers(
|
@@ -3621,6 +4288,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|
3621
4288
|
weight_name=weight_name,
|
3622
4289
|
save_function=save_function,
|
3623
4290
|
safe_serialization=safe_serialization,
|
4291
|
+
lora_adapter_metadata=lora_adapter_metadata,
|
3624
4292
|
)
|
3625
4293
|
|
3626
4294
|
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
@@ -3690,9 +4358,9 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
         super().unfuse_lora(components=components, **kwargs)


-class
+class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
     r"""
-    Load LoRA layers into [`
+    Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
     """

     _lora_loadable_modules = ["transformer"]
@@ -3700,7 +4368,6 @@ class SanaLoraLoaderMixin(LoraBaseMixin):

     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -3711,7 +4378,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):

         <Tip warning={true}>

-        We support loading
+        We support loading original format HunyuanVideo LoRA checkpoints.

         This function is experimental and might change in the future.

@@ -3749,7 +4416,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # transformer and text encoder or both.
@@ -3762,18 +4430,16 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -3794,11 +4460,20 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

-
+        is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
+        if is_original_hunyuan_video:
+            state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
+
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
     def load_lora_weights(
-        self,
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name: Optional[str] = None,
+        hotswap: bool = False,
+        **kwargs,
     ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
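The new `return_lora_metadata` flag changes what `lora_state_dict` returns. A hedged usage sketch; the repository id and weight name are placeholders, not real checkpoints:

```py
from diffusers import HunyuanVideoPipeline

# Only `return_lora_metadata` mirrors the diff above; everything else is illustrative.
state_dict, metadata = HunyuanVideoPipeline.lora_state_dict(
    "some-user/some-hunyuanvideo-lora",
    weight_name="pytorch_lora_weights.safetensors",
    return_lora_metadata=True,
)
print(len(state_dict), metadata)  # metadata may be None if the file carries none
```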
@@ -3816,6 +4491,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
             kwargs (`dict`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
@@ -3833,7 +4510,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -3843,14 +4521,23 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )

     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
     def load_lora_into_transformer(
-        cls,
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+        metadata=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -3860,7 +4547,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
-            transformer (`
+            transformer (`HunyuanVideoTransformer3DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
@@ -3868,29 +4555,11 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
-            hotswap
-
-
-
-
-                torch.compile, loading the new adapter does not require recompilation of the model. When using
-                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
-
-                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
-                to call an additional method before loading the adapter:
-
-                ```py
-                pipeline = ...  # load diffusers pipeline
-                max_rank = ...  # the highest rank among all LoRAs that you want to load
-                # call *before* compiling and loading the LoRA adapter
-                pipeline.enable_lora_hotswap(target_rank=max_rank)
-                pipeline.load_lora_weights(file_name)
-                # optionally compile the model now
-                ```
-
-                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
-                limitations to this technique, which are documented here:
-                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
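The long inline hotswapping explanation is replaced here by a cross-reference, but the workflow it described still applies. A hedged sketch of that flow, assuming `pipe` is an already-constructed pipeline and the adapter files and rank are placeholders:

```py
import torch

# pipe: a diffusers pipeline created elsewhere (placeholder)
pipe.enable_lora_hotswap(target_rank=64)              # call before compiling and before the first load
pipe.load_lora_weights("style_a.safetensors", adapter_name="style")
pipe.transformer = torch.compile(pipe.transformer)    # optional
# Later, swap the adapter weights in place under the same adapter_name; no recompilation is triggered.
pipe.load_lora_weights("style_b.safetensors", adapter_name="style", hotswap=True)
```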
@@ -3903,6 +4572,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
             state_dict,
             network_alphas=None,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -3918,9 +4588,10 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
     ):
         r"""
-        Save the LoRA parameters corresponding to the
+        Save the LoRA parameters corresponding to the transformer.

         Arguments:
             save_directory (`str` or `os.PathLike`):
@@ -3937,14 +4608,21 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            transformer_lora_adapter_metadata:
+                LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-
-
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        if transformer_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
+            )

         # Save the model
         cls.write_lora_layers(
@@ -3954,6 +4632,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
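`_pack_dict_with_prefix` is only called in this diff, never defined, so its exact behaviour is not shown here. An illustrative guess at what prefix-packing the adapter metadata could look like — the helper name and behaviour below are assumptions, not the library's implementation:

```py
def pack_dict_with_prefix(metadata: dict, prefix: str) -> dict:
    # Assumed behaviour: namespace every metadata key under the module name,
    # e.g. {"r": 16} -> {"transformer.r": 16}, so metadata for several modules can coexist.
    return {f"{prefix}.{key}": value for key, value in metadata.items()}

print(pack_dict_with_prefix({"r": 16, "lora_alpha": 16}, "transformer"))
```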
@@ -4023,9 +4702,9 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
         super().unfuse_lora(components=components, **kwargs)


-class
+class Lumina2LoraLoaderMixin(LoraBaseMixin):
     r"""
-    Load LoRA layers into [`
+    Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`].
     """

     _lora_loadable_modules = ["transformer"]
@@ -4043,7 +4722,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):

         <Tip warning={true}>

-        We support loading
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.

         This function is experimental and might change in the future.

@@ -4081,7 +4760,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # transformer and text encoder or both.
@@ -4094,18 +4774,16 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -4126,15 +4804,21 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

-
-
-
+        # conversion.
+        non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
+        if non_diffusers:
+            state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)

-
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
     def load_lora_weights(
-        self,
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name: Optional[str] = None,
+        hotswap: bool = False,
+        **kwargs,
     ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
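The Lumina2 loader above now routes `diffusion_model.`-prefixed (non-diffusers) checkpoints through a converter before returning the state dict, so callers load as usual. A hedged end-to-end sketch; the model and LoRA repository ids are placeholders:

```py
import torch
from diffusers import Lumina2Text2ImgPipeline

pipe = Lumina2Text2ImgPipeline.from_pretrained("some-org/some-lumina2-model", torch_dtype=torch.bfloat16)
# Checkpoints whose keys start with "diffusion_model." are converted internally.
pipe.load_lora_weights("some-user/some-lumina2-lora", adapter_name="style")
```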
@@ -4152,6 +4836,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
             kwargs (`dict`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
@@ -4169,7 +4855,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -4179,54 +4866,45 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )

     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
     def load_lora_into_transformer(
-        cls,
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+        metadata=None,
     ):
         """
-        This will load the LoRA layers specified in `state_dict` into `transformer`.
-
-        Parameters:
-            state_dict (`dict`):
-                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
-                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
-                encoder lora layers.
-            transformer (`
-                The Transformer model to load the LoRA layers into.
-            adapter_name (`str`, *optional*):
-                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
-                `default_{i}` where i is the total number of adapters being loaded.
-            low_cpu_mem_usage (`bool`, *optional*):
-                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
-                weights.
-            hotswap
-
-
-
-
-                torch.compile, loading the new adapter does not require recompilation of the model. When using
-                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
-
-                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
-                to call an additional method before loading the adapter:
-
-                ```py
-                pipeline = ...  # load diffusers pipeline
-                max_rank = ...  # the highest rank among all LoRAs that you want to load
-                # call *before* compiling and loading the LoRA adapter
-                pipeline.enable_lora_hotswap(target_rank=max_rank)
-                pipeline.load_lora_weights(file_name)
-                # optionally compile the model now
-                ```
-
-                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
-                limitations to this technique, which are documented here:
-                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`Lumina2Transformer2DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
@@ -4239,6 +4917,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
             state_dict,
             network_alphas=None,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -4254,9 +4933,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
     ):
         r"""
-        Save the LoRA parameters corresponding to the
+        Save the LoRA parameters corresponding to the transformer.

         Arguments:
             save_directory (`str` or `os.PathLike`):
@@ -4273,14 +4953,21 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            transformer_lora_adapter_metadata:
+                LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-
-
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        if transformer_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
+            )

         # Save the model
         cls.write_lora_layers(
@@ -4290,9 +4977,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

-    # Copied from diffusers.loaders.lora_pipeline.
+    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -4340,7 +5028,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
             **kwargs,
         )

-    # Copied from diffusers.loaders.lora_pipeline.
+    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
@@ -4359,9 +5047,9 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
         super().unfuse_lora(components=components, **kwargs)


-class
+class WanLoraLoaderMixin(LoraBaseMixin):
     r"""
-    Load LoRA layers into [`
+    Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
     """

     _lora_loadable_modules = ["transformer"]
@@ -4417,7 +5105,8 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # transformer and text encoder or both.
@@ -4430,18 +5119,16 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -4455,6 +5142,10 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        if any(k.startswith("diffusion_model.") for k in state_dict):
+            state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
+        elif any(k.startswith("lora_unet_") for k in state_dict):
+            state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)

         is_dora_scale_present = any("dora_scale" in k for k in state_dict)
         if is_dora_scale_present:
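The Wan loader now recognises two non-diffusers key layouts before handing the state dict on. A small, self-contained illustration of the prefix checks; the example keys below are invented for demonstration:

```py
example_key_sets = {
    "original-style": ["diffusion_model.blocks.0.self_attn.q.lora_A.weight"],
    "musubi-style": ["lora_unet_blocks_0_self_attn_q.lora_down.weight"],
    "diffusers": ["transformer.blocks.0.attn1.to_q.lora_A.weight"],
}
for name, keys in example_key_sets.items():
    if any(k.startswith("diffusion_model.") for k in keys):
        route = "_convert_non_diffusers_wan_lora_to_diffusers"
    elif any(k.startswith("lora_unet_") for k in keys):
        route = "_convert_musubi_wan_lora_to_diffusers"
    else:
        route = "no conversion needed"
    print(name, "->", route)
```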
@@ -4462,16 +5153,63 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

-
-
-
-
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out
+
+    @classmethod
+    def _maybe_expand_t2v_lora_for_i2v(
+        cls,
+        transformer: torch.nn.Module,
+        state_dict,
+    ):
+        if transformer.config.image_dim is None:
+            return state_dict
+
+        target_device = transformer.device
+
+        if any(k.startswith("transformer.blocks.") for k in state_dict):
+            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
+            is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
+            has_bias = any(".lora_B.bias" in k for k in state_dict)
+
+            if is_i2v_lora:
+                return state_dict
+
+            for i in range(num_blocks):
+                for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                    # These keys should exist if the block `i` was part of the T2V LoRA.
+                    ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
+                    ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
+
+                    if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
+                        continue
+
+                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
+                    )
+                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
+                    )
+
+                    # If the original LoRA had biases (indicated by has_bias)
+                    # AND the specific reference bias key exists for this block.
+
+                    ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
+                    if has_bias and ref_key_lora_B_bias in state_dict:
+                        ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
+                        state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
+                            ref_lora_B_bias_tensor,
+                            device=target_device,
+                        )

         return state_dict

-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
     def load_lora_weights(
-        self,
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name: Optional[str] = None,
+        hotswap: bool = False,
+        **kwargs,
     ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
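`_maybe_expand_t2v_lora_for_i2v` pads a text-to-video LoRA with zero-valued image-attention matrices so it can be attached to the image-to-video transformer without changing its output. A toy, self-contained illustration of that zero-filling idea (the key names and shapes are made up):

```py
import torch

t2v = {
    "transformer.blocks.0.attn2.to_k.lora_A.weight": torch.randn(4, 64),
    "transformer.blocks.0.attn2.to_k.lora_B.weight": torch.randn(64, 4),
}
for proj in ("add_k_proj", "add_v_proj"):           # the I2V-only projections
    for mat in ("lora_A", "lora_B"):
        ref = t2v[f"transformer.blocks.0.attn2.to_k.{mat}.weight"]
        t2v[f"transformer.blocks.0.attn2.{proj}.{mat}.weight"] = torch.zeros_like(ref)
print(len(t2v))  # 6 keys: the original pair plus four zero-filled tensors
```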
@@ -4489,6 +5227,8 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
             kwargs (`dict`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
@@ -4506,8 +5246,13 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-
-
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
+        state_dict = self._maybe_expand_t2v_lora_for_i2v(
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            state_dict=state_dict,
+        )
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
@@ -4516,14 +5261,23 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )

     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
     def load_lora_into_transformer(
-        cls,
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+        metadata=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -4533,7 +5287,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
-            transformer (`
+            transformer (`WanTransformer3DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
@@ -4541,29 +5295,11 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
-            hotswap
-
-
-
-
-                torch.compile, loading the new adapter does not require recompilation of the model. When using
-                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
-
-                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
-                to call an additional method before loading the adapter:
-
-                ```py
-                pipeline = ...  # load diffusers pipeline
-                max_rank = ...  # the highest rank among all LoRAs that you want to load
-                # call *before* compiling and loading the LoRA adapter
-                pipeline.enable_lora_hotswap(target_rank=max_rank)
-                pipeline.load_lora_weights(file_name)
-                # optionally compile the model now
-                ```
-
-                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
-                limitations to this technique, which are documented here:
-                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
@@ -4576,6 +5312,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
             state_dict,
             network_alphas=None,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -4591,9 +5328,10 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
     ):
         r"""
-        Save the LoRA parameters corresponding to the
+        Save the LoRA parameters corresponding to the transformer.

         Arguments:
             save_directory (`str` or `os.PathLike`):
@@ -4610,14 +5348,21 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            transformer_lora_adapter_metadata:
+                LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-
-
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        if transformer_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
+            )

         # Save the model
         cls.write_lora_layers(
@@ -4627,9 +5372,10 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

-    # Copied from diffusers.loaders.lora_pipeline.
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -4677,7 +5423,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
             **kwargs,
         )

-    # Copied from diffusers.loaders.lora_pipeline.
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
@@ -4696,9 +5442,9 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
         super().unfuse_lora(components=components, **kwargs)


-class
+class CogView4LoraLoaderMixin(LoraBaseMixin):
     r"""
-    Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`
+    Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`].
     """

     _lora_loadable_modules = ["transformer"]
@@ -4706,6 +5452,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):

     @classmethod
     @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -4754,6 +5501,8 @@ class WanLoraLoaderMixin(LoraBaseMixin):
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.

         """
         # Load the main state dict first which has the LoRA layers for either of
@@ -4767,18 +5516,16 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -4792,8 +5539,6 @@ class WanLoraLoaderMixin(LoraBaseMixin):
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
-        if any(k.startswith("diffusion_model.") for k in state_dict):
-            state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)

         is_dora_scale_present = any("dora_scale" in k for k in state_dict)
         if is_dora_scale_present:
@@ -4801,37 +5546,16 @@ class WanLoraLoaderMixin(LoraBaseMixin):
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

-
-
-    @classmethod
-    def _maybe_expand_t2v_lora_for_i2v(
-        cls,
-        transformer: torch.nn.Module,
-        state_dict,
-    ):
-        if transformer.config.image_dim is None:
-            return state_dict
-
-        if any(k.startswith("transformer.blocks.") for k in state_dict):
-            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
-            is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
-
-            if is_i2v_lora:
-                return state_dict
-
-            for i in range(num_blocks):
-                for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
-                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
-                    )
-                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
-                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
-                    )
-
-            return state_dict
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
     def load_lora_weights(
-        self,
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name: Optional[str] = None,
+        hotswap: bool = False,
+        **kwargs,
     ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
@@ -4849,6 +5573,8 @@ class WanLoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
             kwargs (`dict`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
@@ -4866,12 +5592,9 @@ class WanLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-
-
-
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            state_dict=state_dict,
-        )
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
@@ -4880,14 +5603,23 @@ class WanLoraLoaderMixin(LoraBaseMixin):
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )

     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
     def load_lora_into_transformer(
-        cls,
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+        metadata=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -4897,7 +5629,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
-            transformer (`
+            transformer (`CogView4Transformer2DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
@@ -4905,29 +5637,11 @@ class WanLoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
-            hotswap
-
-
-
-
-                torch.compile, loading the new adapter does not require recompilation of the model. When using
-                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
-
-                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
-                to call an additional method before loading the adapter:
-
-                ```py
-                pipeline = ...  # load diffusers pipeline
-                max_rank = ...  # the highest rank among all LoRAs that you want to load
-                # call *before* compiling and loading the LoRA adapter
-                pipeline.enable_lora_hotswap(target_rank=max_rank)
-                pipeline.load_lora_weights(file_name)
-                # optionally compile the model now
-                ```
-
-                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
-                limitations to this technique, which are documented here:
-                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
@@ -4940,6 +5654,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
             state_dict,
             network_alphas=None,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -4955,9 +5670,10 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
     ):
         r"""
-        Save the LoRA parameters corresponding to the
+        Save the LoRA parameters corresponding to the transformer.

         Arguments:
             save_directory (`str` or `os.PathLike`):
@@ -4974,14 +5690,21 @@ class WanLoraLoaderMixin(LoraBaseMixin):
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            transformer_lora_adapter_metadata:
+                LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-
-
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        if transformer_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
+            )

         # Save the model
         cls.write_lora_layers(
@@ -4991,6 +5714,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -5060,9 +5784,9 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         super().unfuse_lora(components=components, **kwargs)


-class
+class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
     r"""
-    Load LoRA layers into [`
+    Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`].
     """

     _lora_loadable_modules = ["transformer"]
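The mixins above all expose the same save/load surface, so the metadata round trip looks the same for each of them. A hedged sketch using CogView4; the path, the layer dict, and the metadata values are placeholders:

```py
from diffusers import CogView4Pipeline

transformer_lora_layers = ...  # placeholder: dict of LoRA tensors from a trained adapter
CogView4Pipeline.save_lora_weights(
    save_directory="./cogview4-lora",
    transformer_lora_layers=transformer_lora_layers,
    transformer_lora_adapter_metadata={"r": 8, "lora_alpha": 8},
)
state_dict, metadata = CogView4Pipeline.lora_state_dict("./cogview4-lora", return_lora_metadata=True)
```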
@@ -5070,7 +5794,6 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):

     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -5119,7 +5842,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # transformer and text encoder or both.
@@ -5132,18 +5856,16 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -5164,11 +5886,20 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

-
+        is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
+        if is_non_diffusers_format:
+            state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
+
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
     def load_lora_weights(
-        self,
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name: Optional[str] = None,
+        hotswap: bool = False,
+        **kwargs,
     ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
@@ -5186,6 +5917,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
             kwargs (`dict`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
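The HiDream-Image loader above converts checkpoints whose keys contain `diffusion_model` before returning them. A hedged sketch of what a caller sees; the repository id is a placeholder:

```py
from diffusers import HiDreamImagePipeline

# Keys containing "diffusion_model" are converted to the diffusers layout inside lora_state_dict,
# so the returned dict should already use diffusers-style keys.
state_dict = HiDreamImagePipeline.lora_state_dict("some-user/some-hidream-lora")
print(any("diffusion_model" in k for k in state_dict))  # expected: False after conversion
```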
@@ -5203,7 +5936,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -5213,14 +5947,23 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )

     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel
     def load_lora_into_transformer(
-        cls,
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+        metadata=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -5230,7 +5973,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
-            transformer (`
+            transformer (`HiDreamImageTransformer2DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
@@ -5238,29 +5981,11 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
-            hotswap
-
-
-
-
-                torch.compile, loading the new adapter does not require recompilation of the model. When using
-                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
-
-                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
-                to call an additional method before loading the adapter:
-
-                ```py
-                pipeline = ...  # load diffusers pipeline
-                max_rank = ...  # the highest rank among all LoRAs that you want to load
-                # call *before* compiling and loading the LoRA adapter
-                pipeline.enable_lora_hotswap(target_rank=max_rank)
-                pipeline.load_lora_weights(file_name)
-                # optionally compile the model now
-                ```
-
-                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
-                limitations to this technique, which are documented here:
-                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
@@ -5273,6 +5998,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             state_dict,
             network_alphas=None,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -5288,9 +6014,10 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
     ):
         r"""
-        Save the LoRA parameters corresponding to the
+        Save the LoRA parameters corresponding to the transformer.

         Arguments:
             save_directory (`str` or `os.PathLike`):
@@ -5307,14 +6034,21 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            transformer_lora_adapter_metadata:
+                LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-
-
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        if transformer_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
+            )

         # Save the model
         cls.write_lora_layers(
@@ -5324,9 +6058,10 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

-    # Copied from diffusers.loaders.lora_pipeline.
+    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -5374,7 +6109,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             **kwargs,
         )

-    # Copied from diffusers.loaders.lora_pipeline.
+    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of