diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/lora_base.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -14,6 +14,7 @@
|
|
14
14
|
|
15
15
|
import copy
|
16
16
|
import inspect
|
17
|
+
import json
|
17
18
|
import os
|
18
19
|
from pathlib import Path
|
19
20
|
from typing import Callable, Dict, List, Optional, Union
|
@@ -33,7 +34,6 @@ from ..utils import (
|
|
33
34
|
delete_adapter_layers,
|
34
35
|
deprecate,
|
35
36
|
get_adapter_name,
|
36
|
-
get_peft_kwargs,
|
37
37
|
is_accelerate_available,
|
38
38
|
is_peft_available,
|
39
39
|
is_peft_version,
|
@@ -45,13 +45,13 @@ from ..utils import (
|
|
45
45
|
set_adapter_layers,
|
46
46
|
set_weights_and_activate_adapters,
|
47
47
|
)
|
48
|
+
from ..utils.peft_utils import _create_lora_config
|
49
|
+
from ..utils.state_dict_utils import _load_sft_state_dict_metadata
|
48
50
|
|
49
51
|
|
50
52
|
if is_transformers_available():
|
51
53
|
from transformers import PreTrainedModel
|
52
54
|
|
53
|
-
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
54
|
-
|
55
55
|
if is_peft_available():
|
56
56
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
57
57
|
|
@@ -62,6 +62,7 @@ logger = logging.get_logger(__name__)
|
|
62
62
|
|
63
63
|
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
64
64
|
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
65
|
+
LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"
|
65
66
|
|
66
67
|
|
67
68
|
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
@@ -206,6 +207,7 @@ def _fetch_state_dict(
|
|
206
207
|
subfolder,
|
207
208
|
user_agent,
|
208
209
|
allow_pickle,
|
210
|
+
metadata=None,
|
209
211
|
):
|
210
212
|
model_file = None
|
211
213
|
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
@@ -236,11 +238,14 @@ def _fetch_state_dict(
|
|
236
238
|
user_agent=user_agent,
|
237
239
|
)
|
238
240
|
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
241
|
+
metadata = _load_sft_state_dict_metadata(model_file)
|
242
|
+
|
239
243
|
except (IOError, safetensors.SafetensorError) as e:
|
240
244
|
if not allow_pickle:
|
241
245
|
raise e
|
242
246
|
# try loading non-safetensors weights
|
243
247
|
model_file = None
|
248
|
+
metadata = None
|
244
249
|
pass
|
245
250
|
|
246
251
|
if model_file is None:
|
@@ -261,10 +266,11 @@ def _fetch_state_dict(
|
|
261
266
|
user_agent=user_agent,
|
262
267
|
)
|
263
268
|
state_dict = load_state_dict(model_file)
|
269
|
+
metadata = None
|
264
270
|
else:
|
265
271
|
state_dict = pretrained_model_name_or_path_or_dict
|
266
272
|
|
267
|
-
return state_dict
|
273
|
+
return state_dict, metadata
|
268
274
|
|
269
275
|
|
270
276
|
def _best_guess_weight_name(
|
@@ -299,13 +305,18 @@ def _best_guess_weight_name(
|
|
299
305
|
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
|
300
306
|
|
301
307
|
if len(targeted_files) > 1:
|
302
|
-
|
303
|
-
f"Provided path contains more than one weights file in the {file_extension} format.
|
308
|
+
logger.warning(
|
309
|
+
f"Provided path contains more than one weights file in the {file_extension} format. `{targeted_files[0]}` is going to be loaded, for precise control, specify a `weight_name` in `load_lora_weights`."
|
304
310
|
)
|
305
311
|
weight_name = targeted_files[0]
|
306
312
|
return weight_name
|
307
313
|
|
308
314
|
|
315
|
+
def _pack_dict_with_prefix(state_dict, prefix):
|
316
|
+
sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
|
317
|
+
return sd_with_prefix
|
318
|
+
|
319
|
+
|
309
320
|
def _load_lora_into_text_encoder(
|
310
321
|
state_dict,
|
311
322
|
network_alphas,
|
@@ -317,10 +328,16 @@ def _load_lora_into_text_encoder(
|
|
317
328
|
_pipeline=None,
|
318
329
|
low_cpu_mem_usage=False,
|
319
330
|
hotswap: bool = False,
|
331
|
+
metadata=None,
|
320
332
|
):
|
333
|
+
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
|
334
|
+
|
321
335
|
if not USE_PEFT_BACKEND:
|
322
336
|
raise ValueError("PEFT backend is required for this method.")
|
323
337
|
|
338
|
+
if network_alphas and metadata:
|
339
|
+
raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.")
|
340
|
+
|
324
341
|
peft_kwargs = {}
|
325
342
|
if low_cpu_mem_usage:
|
326
343
|
if not is_peft_version(">=", "0.13.1"):
|
@@ -335,8 +352,6 @@ def _load_lora_into_text_encoder(
|
|
335
352
|
)
|
336
353
|
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
337
354
|
|
338
|
-
from peft import LoraConfig
|
339
|
-
|
340
355
|
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
341
356
|
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
|
342
357
|
# their prefixes.
|
@@ -348,7 +363,9 @@ def _load_lora_into_text_encoder(
|
|
348
363
|
|
349
364
|
# Load the layers corresponding to text encoder and make necessary adjustments.
|
350
365
|
if prefix is not None:
|
351
|
-
state_dict = {k
|
366
|
+
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
367
|
+
if metadata is not None:
|
368
|
+
metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
|
352
369
|
|
353
370
|
if len(state_dict) > 0:
|
354
371
|
logger.info(f"Loading {prefix}.")
|
@@ -358,54 +375,27 @@ def _load_lora_into_text_encoder(
|
|
358
375
|
# convert state dict
|
359
376
|
state_dict = convert_state_dict_to_peft(state_dict)
|
360
377
|
|
361
|
-
for name, _ in
|
362
|
-
|
363
|
-
rank_key = f"{name}.
|
364
|
-
if rank_key
|
365
|
-
|
366
|
-
rank[rank_key] = state_dict[rank_key].shape[1]
|
367
|
-
|
368
|
-
for name, _ in text_encoder_mlp_modules(text_encoder):
|
369
|
-
for module in ("fc1", "fc2"):
|
370
|
-
rank_key = f"{name}.{module}.lora_B.weight"
|
371
|
-
if rank_key not in state_dict:
|
372
|
-
continue
|
373
|
-
rank[rank_key] = state_dict[rank_key].shape[1]
|
378
|
+
for name, _ in text_encoder.named_modules():
|
379
|
+
if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
|
380
|
+
rank_key = f"{name}.lora_B.weight"
|
381
|
+
if rank_key in state_dict:
|
382
|
+
rank[rank_key] = state_dict[rank_key].shape[1]
|
374
383
|
|
375
384
|
if network_alphas is not None:
|
376
385
|
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
377
|
-
network_alphas = {k.
|
378
|
-
|
379
|
-
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
|
380
|
-
|
381
|
-
if "use_dora" in lora_config_kwargs:
|
382
|
-
if lora_config_kwargs["use_dora"]:
|
383
|
-
if is_peft_version("<", "0.9.0"):
|
384
|
-
raise ValueError(
|
385
|
-
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
386
|
-
)
|
387
|
-
else:
|
388
|
-
if is_peft_version("<", "0.9.0"):
|
389
|
-
lora_config_kwargs.pop("use_dora")
|
390
|
-
|
391
|
-
if "lora_bias" in lora_config_kwargs:
|
392
|
-
if lora_config_kwargs["lora_bias"]:
|
393
|
-
if is_peft_version("<=", "0.13.2"):
|
394
|
-
raise ValueError(
|
395
|
-
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
396
|
-
)
|
397
|
-
else:
|
398
|
-
if is_peft_version("<=", "0.13.2"):
|
399
|
-
lora_config_kwargs.pop("lora_bias")
|
386
|
+
network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
|
400
387
|
|
401
|
-
|
388
|
+
# create `LoraConfig`
|
389
|
+
lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False)
|
402
390
|
|
403
391
|
# adapter_name
|
404
392
|
if adapter_name is None:
|
405
393
|
adapter_name = get_adapter_name(text_encoder)
|
406
394
|
|
407
|
-
|
408
|
-
|
395
|
+
# <Unsafe code
|
396
|
+
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
|
397
|
+
_pipeline
|
398
|
+
)
|
409
399
|
# inject LoRA layers and load the state dict
|
410
400
|
# in transformers we automatically check whether the adapter name is already in use or not
|
411
401
|
text_encoder.load_adapter(
|
@@ -417,7 +407,6 @@ def _load_lora_into_text_encoder(
|
|
417
407
|
|
418
408
|
# scale LoRA layers with `lora_scale`
|
419
409
|
scale_lora_layers(text_encoder, weight=lora_scale)
|
420
|
-
|
421
410
|
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
422
411
|
|
423
412
|
# Offload back.
|
@@ -425,47 +414,90 @@ def _load_lora_into_text_encoder(
|
|
425
414
|
_pipeline.enable_model_cpu_offload()
|
426
415
|
elif is_sequential_cpu_offload:
|
427
416
|
_pipeline.enable_sequential_cpu_offload()
|
417
|
+
elif is_group_offload:
|
418
|
+
for component in _pipeline.components.values():
|
419
|
+
if isinstance(component, torch.nn.Module):
|
420
|
+
_maybe_remove_and_reapply_group_offloading(component)
|
428
421
|
# Unsafe code />
|
429
422
|
|
430
423
|
if prefix is not None and not state_dict:
|
424
|
+
model_class_name = text_encoder.__class__.__name__
|
431
425
|
logger.warning(
|
432
|
-
f"No LoRA keys associated to {
|
426
|
+
f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
|
433
427
|
"This is safe to ignore if LoRA state dict didn't originally have any "
|
434
|
-
f"{
|
428
|
+
f"{model_class_name} related params. You can also try specifying `prefix=None` "
|
435
429
|
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
|
436
430
|
"https://github.com/huggingface/diffusers/issues/new"
|
437
431
|
)
|
438
432
|
|
439
433
|
|
440
434
|
def _func_optionally_disable_offloading(_pipeline):
|
435
|
+
"""
|
436
|
+
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
437
|
+
|
438
|
+
Args:
|
439
|
+
_pipeline (`DiffusionPipeline`):
|
440
|
+
The pipeline to disable offloading for.
|
441
|
+
|
442
|
+
Returns:
|
443
|
+
tuple:
|
444
|
+
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
|
445
|
+
"""
|
446
|
+
from ..hooks.group_offloading import _is_group_offload_enabled
|
447
|
+
|
441
448
|
is_model_cpu_offload = False
|
442
449
|
is_sequential_cpu_offload = False
|
450
|
+
is_group_offload = False
|
443
451
|
|
444
452
|
if _pipeline is not None and _pipeline.hf_device_map is None:
|
445
453
|
for _, component in _pipeline.components.items():
|
446
|
-
if isinstance(component, nn.Module)
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
454
|
+
if not isinstance(component, nn.Module):
|
455
|
+
continue
|
456
|
+
is_group_offload = is_group_offload or _is_group_offload_enabled(component)
|
457
|
+
if not hasattr(component, "_hf_hook"):
|
458
|
+
continue
|
459
|
+
is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
|
460
|
+
is_sequential_cpu_offload = is_sequential_cpu_offload or (
|
461
|
+
isinstance(component._hf_hook, AlignDevicesHook)
|
462
|
+
or hasattr(component._hf_hook, "hooks")
|
463
|
+
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
464
|
+
)
|
455
465
|
|
456
|
-
|
457
|
-
|
458
|
-
)
|
466
|
+
if is_sequential_cpu_offload or is_model_cpu_offload:
|
467
|
+
logger.info(
|
468
|
+
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
469
|
+
)
|
470
|
+
for _, component in _pipeline.components.items():
|
471
|
+
if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
|
472
|
+
continue
|
459
473
|
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
460
474
|
|
461
|
-
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
475
|
+
return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
|
462
476
|
|
463
477
|
|
464
478
|
class LoraBaseMixin:
|
465
479
|
"""Utility class for handling LoRAs."""
|
466
480
|
|
467
481
|
_lora_loadable_modules = []
|
468
|
-
|
482
|
+
_merged_adapters = set()
|
483
|
+
|
484
|
+
@property
|
485
|
+
def lora_scale(self) -> float:
|
486
|
+
"""
|
487
|
+
Returns the lora scale which can be set at run time by the pipeline. # if `_lora_scale` has not been set,
|
488
|
+
return 1.
|
489
|
+
"""
|
490
|
+
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
491
|
+
|
492
|
+
@property
|
493
|
+
def num_fused_loras(self):
|
494
|
+
"""Returns the number of LoRAs that have been fused."""
|
495
|
+
return len(self._merged_adapters)
|
496
|
+
|
497
|
+
@property
|
498
|
+
def fused_loras(self):
|
499
|
+
"""Returns names of the LoRAs that have been fused."""
|
500
|
+
return self._merged_adapters
|
469
501
|
|
470
502
|
def load_lora_weights(self, **kwargs):
|
471
503
|
raise NotImplementedError("`load_lora_weights()` is not implemented.")
|
@@ -478,33 +510,6 @@ class LoraBaseMixin:
|
|
478
510
|
def lora_state_dict(cls, **kwargs):
|
479
511
|
raise NotImplementedError("`lora_state_dict()` is not implemented.")
|
480
512
|
|
481
|
-
@classmethod
|
482
|
-
def _optionally_disable_offloading(cls, _pipeline):
|
483
|
-
"""
|
484
|
-
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
485
|
-
|
486
|
-
Args:
|
487
|
-
_pipeline (`DiffusionPipeline`):
|
488
|
-
The pipeline to disable offloading for.
|
489
|
-
|
490
|
-
Returns:
|
491
|
-
tuple:
|
492
|
-
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
493
|
-
"""
|
494
|
-
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
495
|
-
|
496
|
-
@classmethod
|
497
|
-
def _fetch_state_dict(cls, *args, **kwargs):
|
498
|
-
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
|
499
|
-
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
|
500
|
-
return _fetch_state_dict(*args, **kwargs)
|
501
|
-
|
502
|
-
@classmethod
|
503
|
-
def _best_guess_weight_name(cls, *args, **kwargs):
|
504
|
-
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
|
505
|
-
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
|
506
|
-
return _best_guess_weight_name(*args, **kwargs)
|
507
|
-
|
508
513
|
def unload_lora_weights(self):
|
509
514
|
"""
|
510
515
|
Unloads the LoRA parameters.
|
@@ -592,6 +597,9 @@ class LoraBaseMixin:
|
|
592
597
|
if len(components) == 0:
|
593
598
|
raise ValueError("`components` cannot be an empty list.")
|
594
599
|
|
600
|
+
# Need to retrieve the names as `adapter_names` can be None. So we cannot directly use it
|
601
|
+
# in `self._merged_adapters = self._merged_adapters | merged_adapter_names`.
|
602
|
+
merged_adapter_names = set()
|
595
603
|
for fuse_component in components:
|
596
604
|
if fuse_component not in self._lora_loadable_modules:
|
597
605
|
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
|
@@ -601,13 +609,19 @@ class LoraBaseMixin:
|
|
601
609
|
# check if diffusers model
|
602
610
|
if issubclass(model.__class__, ModelMixin):
|
603
611
|
model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
|
612
|
+
for module in model.modules():
|
613
|
+
if isinstance(module, BaseTunerLayer):
|
614
|
+
merged_adapter_names.update(set(module.merged_adapters))
|
604
615
|
# handle transformers models.
|
605
616
|
if issubclass(model.__class__, PreTrainedModel):
|
606
617
|
fuse_text_encoder_lora(
|
607
618
|
model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
608
619
|
)
|
620
|
+
for module in model.modules():
|
621
|
+
if isinstance(module, BaseTunerLayer):
|
622
|
+
merged_adapter_names.update(set(module.merged_adapters))
|
609
623
|
|
610
|
-
self.
|
624
|
+
self._merged_adapters = self._merged_adapters | merged_adapter_names
|
611
625
|
|
612
626
|
def unfuse_lora(self, components: List[str] = [], **kwargs):
|
613
627
|
r"""
|
@@ -661,15 +675,42 @@ class LoraBaseMixin:
|
|
661
675
|
if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
|
662
676
|
for module in model.modules():
|
663
677
|
if isinstance(module, BaseTunerLayer):
|
678
|
+
for adapter in set(module.merged_adapters):
|
679
|
+
if adapter and adapter in self._merged_adapters:
|
680
|
+
self._merged_adapters = self._merged_adapters - {adapter}
|
664
681
|
module.unmerge()
|
665
682
|
|
666
|
-
self.num_fused_loras -= 1
|
667
|
-
|
668
683
|
def set_adapters(
|
669
684
|
self,
|
670
685
|
adapter_names: Union[List[str], str],
|
671
686
|
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
|
672
687
|
):
|
688
|
+
"""
|
689
|
+
Set the currently active adapters for use in the pipeline.
|
690
|
+
|
691
|
+
Args:
|
692
|
+
adapter_names (`List[str]` or `str`):
|
693
|
+
The names of the adapters to use.
|
694
|
+
adapter_weights (`Union[List[float], float]`, *optional*):
|
695
|
+
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
|
696
|
+
adapters.
|
697
|
+
|
698
|
+
Example:
|
699
|
+
|
700
|
+
```py
|
701
|
+
from diffusers import AutoPipelineForText2Image
|
702
|
+
import torch
|
703
|
+
|
704
|
+
pipeline = AutoPipelineForText2Image.from_pretrained(
|
705
|
+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
706
|
+
).to("cuda")
|
707
|
+
pipeline.load_lora_weights(
|
708
|
+
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
|
709
|
+
)
|
710
|
+
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
711
|
+
pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
|
712
|
+
```
|
713
|
+
"""
|
673
714
|
if isinstance(adapter_weights, dict):
|
674
715
|
components_passed = set(adapter_weights.keys())
|
675
716
|
lora_components = set(self._lora_loadable_modules)
|
@@ -713,7 +754,11 @@ class LoraBaseMixin:
|
|
713
754
|
# Decompose weights into weights for denoiser and text encoders.
|
714
755
|
_component_adapter_weights = {}
|
715
756
|
for component in self._lora_loadable_modules:
|
716
|
-
model = getattr(self, component)
|
757
|
+
model = getattr(self, component, None)
|
758
|
+
# To guard for cases like Wan. In Wan2.1 and WanVace, we have a single denoiser.
|
759
|
+
# Whereas in Wan 2.2, we have two denoisers.
|
760
|
+
if model is None:
|
761
|
+
continue
|
717
762
|
|
718
763
|
for adapter_name, weights in zip(adapter_names, adapter_weights):
|
719
764
|
if isinstance(weights, dict):
|
@@ -739,6 +784,24 @@ class LoraBaseMixin:
|
|
739
784
|
set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
|
740
785
|
|
741
786
|
def disable_lora(self):
|
787
|
+
"""
|
788
|
+
Disables the active LoRA layers of the pipeline.
|
789
|
+
|
790
|
+
Example:
|
791
|
+
|
792
|
+
```py
|
793
|
+
from diffusers import AutoPipelineForText2Image
|
794
|
+
import torch
|
795
|
+
|
796
|
+
pipeline = AutoPipelineForText2Image.from_pretrained(
|
797
|
+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
798
|
+
).to("cuda")
|
799
|
+
pipeline.load_lora_weights(
|
800
|
+
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
|
801
|
+
)
|
802
|
+
pipeline.disable_lora()
|
803
|
+
```
|
804
|
+
"""
|
742
805
|
if not USE_PEFT_BACKEND:
|
743
806
|
raise ValueError("PEFT backend is required for this method.")
|
744
807
|
|
@@ -751,6 +814,24 @@ class LoraBaseMixin:
|
|
751
814
|
disable_lora_for_text_encoder(model)
|
752
815
|
|
753
816
|
def enable_lora(self):
|
817
|
+
"""
|
818
|
+
Enables the active LoRA layers of the pipeline.
|
819
|
+
|
820
|
+
Example:
|
821
|
+
|
822
|
+
```py
|
823
|
+
from diffusers import AutoPipelineForText2Image
|
824
|
+
import torch
|
825
|
+
|
826
|
+
pipeline = AutoPipelineForText2Image.from_pretrained(
|
827
|
+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
828
|
+
).to("cuda")
|
829
|
+
pipeline.load_lora_weights(
|
830
|
+
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
|
831
|
+
)
|
832
|
+
pipeline.enable_lora()
|
833
|
+
```
|
834
|
+
"""
|
754
835
|
if not USE_PEFT_BACKEND:
|
755
836
|
raise ValueError("PEFT backend is required for this method.")
|
756
837
|
|
@@ -764,10 +845,26 @@ class LoraBaseMixin:
|
|
764
845
|
|
765
846
|
def delete_adapters(self, adapter_names: Union[List[str], str]):
|
766
847
|
"""
|
848
|
+
Delete an adapter's LoRA layers from the pipeline.
|
849
|
+
|
767
850
|
Args:
|
768
|
-
Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
|
769
851
|
adapter_names (`Union[List[str], str]`):
|
770
|
-
The names of the
|
852
|
+
The names of the adapters to delete.
|
853
|
+
|
854
|
+
Example:
|
855
|
+
|
856
|
+
```py
|
857
|
+
from diffusers import AutoPipelineForText2Image
|
858
|
+
import torch
|
859
|
+
|
860
|
+
pipeline = AutoPipelineForText2Image.from_pretrained(
|
861
|
+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
862
|
+
).to("cuda")
|
863
|
+
pipeline.load_lora_weights(
|
864
|
+
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
|
865
|
+
)
|
866
|
+
pipeline.delete_adapters("cinematic")
|
867
|
+
```
|
771
868
|
"""
|
772
869
|
if not USE_PEFT_BACKEND:
|
773
870
|
raise ValueError("PEFT backend is required for this method.")
|
@@ -844,6 +941,27 @@ class LoraBaseMixin:
|
|
844
941
|
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
|
845
942
|
you want to load multiple adapters and free some GPU memory.
|
846
943
|
|
944
|
+
After offloading the LoRA adapters to CPU, as long as the rest of the model is still on GPU, the LoRA adapters
|
945
|
+
can no longer be used for inference, as that would cause a device mismatch. Remember to set the device back to
|
946
|
+
GPU before using those LoRA adapters for inference.
|
947
|
+
|
948
|
+
```python
|
949
|
+
>>> pipe.load_lora_weights(path_1, adapter_name="adapter-1")
|
950
|
+
>>> pipe.load_lora_weights(path_2, adapter_name="adapter-2")
|
951
|
+
>>> pipe.set_adapters("adapter-1")
|
952
|
+
>>> image_1 = pipe(**kwargs)
|
953
|
+
>>> # switch to adapter-2, offload adapter-1
|
954
|
+
>>> pipeline.set_lora_device(adapter_names=["adapter-1"], device="cpu")
|
955
|
+
>>> pipeline.set_lora_device(adapter_names=["adapter-2"], device="cuda:0")
|
956
|
+
>>> pipe.set_adapters("adapter-2")
|
957
|
+
>>> image_2 = pipe(**kwargs)
|
958
|
+
>>> # switch back to adapter-1, offload adapter-2
|
959
|
+
>>> pipeline.set_lora_device(adapter_names=["adapter-2"], device="cpu")
|
960
|
+
>>> pipeline.set_lora_device(adapter_names=["adapter-1"], device="cuda:0")
|
961
|
+
>>> pipe.set_adapters("adapter-1")
|
962
|
+
>>> ...
|
963
|
+
```
|
964
|
+
|
847
965
|
Args:
|
848
966
|
adapter_names (`List[str]`):
|
849
967
|
List of adapters to send device to.
|
@@ -859,6 +977,10 @@ class LoraBaseMixin:
|
|
859
977
|
for module in model.modules():
|
860
978
|
if isinstance(module, BaseTunerLayer):
|
861
979
|
for adapter_name in adapter_names:
|
980
|
+
if adapter_name not in module.lora_A:
|
981
|
+
# it is sufficient to check lora_A
|
982
|
+
continue
|
983
|
+
|
862
984
|
module.lora_A[adapter_name].to(device)
|
863
985
|
module.lora_B[adapter_name].to(device)
|
864
986
|
# this is a param, not a module, so device placement is not in-place -> re-assign
|
@@ -868,11 +990,28 @@ class LoraBaseMixin:
|
|
868
990
|
adapter_name
|
869
991
|
].to(device)
|
870
992
|
|
993
|
+
def enable_lora_hotswap(self, **kwargs) -> None:
|
994
|
+
"""
|
995
|
+
Hotswap adapters without triggering recompilation of a model or if the ranks of the loaded adapters are
|
996
|
+
different.
|
997
|
+
|
998
|
+
Args:
|
999
|
+
target_rank (`int`):
|
1000
|
+
The highest rank among all the adapters that will be loaded.
|
1001
|
+
check_compiled (`str`, *optional*, defaults to `"error"`):
|
1002
|
+
How to handle a model that is already compiled. The check can return the following messages:
|
1003
|
+
- "error" (default): raise an error
|
1004
|
+
- "warn": issue a warning
|
1005
|
+
- "ignore": do nothing
|
1006
|
+
"""
|
1007
|
+
for key, component in self.components.items():
|
1008
|
+
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
|
1009
|
+
component.enable_lora_hotswap(**kwargs)
|
1010
|
+
|
871
1011
|
@staticmethod
|
872
1012
|
def pack_weights(layers, prefix):
|
873
1013
|
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
874
|
-
|
875
|
-
return layers_state_dict
|
1014
|
+
return _pack_dict_with_prefix(layers_weights, prefix)
|
876
1015
|
|
877
1016
|
@staticmethod
|
878
1017
|
def write_lora_layers(
|
@@ -882,16 +1021,33 @@ class LoraBaseMixin:
|
|
882
1021
|
weight_name: str,
|
883
1022
|
save_function: Callable,
|
884
1023
|
safe_serialization: bool,
|
1024
|
+
lora_adapter_metadata: Optional[dict] = None,
|
885
1025
|
):
|
1026
|
+
"""Writes the state dict of the LoRA layers (optionally with metadata) to disk."""
|
886
1027
|
if os.path.isfile(save_directory):
|
887
1028
|
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
888
1029
|
return
|
889
1030
|
|
1031
|
+
if lora_adapter_metadata and not safe_serialization:
|
1032
|
+
raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
|
1033
|
+
if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
|
1034
|
+
raise TypeError("`lora_adapter_metadata` must be of type `dict`.")
|
1035
|
+
|
890
1036
|
if save_function is None:
|
891
1037
|
if safe_serialization:
|
892
1038
|
|
893
1039
|
def save_function(weights, filename):
|
894
|
-
|
1040
|
+
# Inject framework format.
|
1041
|
+
metadata = {"format": "pt"}
|
1042
|
+
if lora_adapter_metadata:
|
1043
|
+
for key, value in lora_adapter_metadata.items():
|
1044
|
+
if isinstance(value, set):
|
1045
|
+
lora_adapter_metadata[key] = list(value)
|
1046
|
+
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
|
1047
|
+
lora_adapter_metadata, indent=2, sort_keys=True
|
1048
|
+
)
|
1049
|
+
|
1050
|
+
return safetensors.torch.save_file(weights, filename, metadata=metadata)
|
895
1051
|
|
896
1052
|
else:
|
897
1053
|
save_function = torch.save
|
@@ -908,28 +1064,6 @@ class LoraBaseMixin:
|
|
908
1064
|
save_function(state_dict, save_path)
|
909
1065
|
logger.info(f"Model weights saved in {save_path}")
|
910
1066
|
|
911
|
-
@
|
912
|
-
def
|
913
|
-
|
914
|
-
# if _lora_scale has not been set, return 1
|
915
|
-
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
916
|
-
|
917
|
-
def enable_lora_hotswap(self, **kwargs) -> None:
|
918
|
-
"""Enables the possibility to hotswap LoRA adapters.
|
919
|
-
|
920
|
-
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
|
921
|
-
the loaded adapters differ.
|
922
|
-
|
923
|
-
Args:
|
924
|
-
target_rank (`int`):
|
925
|
-
The highest rank among all the adapters that will be loaded.
|
926
|
-
check_compiled (`str`, *optional*, defaults to `"error"`):
|
927
|
-
How to handle the case when the model is already compiled, which should generally be avoided. The
|
928
|
-
options are:
|
929
|
-
- "error" (default): raise an error
|
930
|
-
- "warn": issue a warning
|
931
|
-
- "ignore": do nothing
|
932
|
-
"""
|
933
|
-
for key, component in self.components.items():
|
934
|
-
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
|
935
|
-
component.enable_lora_hotswap(**kwargs)
|
1067
|
+
@classmethod
|
1068
|
+
def _optionally_disable_offloading(cls, _pipeline):
|
1069
|
+
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|