diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/peft.py
CHANGED
@@ -13,6 +13,7 @@
|
|
13
13
|
# See the License for the specific language governing permissions and
|
14
14
|
# limitations under the License.
|
15
15
|
import inspect
|
16
|
+
import json
|
16
17
|
import os
|
17
18
|
from functools import partial
|
18
19
|
from pathlib import Path
|
@@ -28,13 +29,13 @@ from ..utils import (
|
|
28
29
|
convert_unet_state_dict_to_peft,
|
29
30
|
delete_adapter_layers,
|
30
31
|
get_adapter_name,
|
31
|
-
get_peft_kwargs,
|
32
32
|
is_peft_available,
|
33
33
|
is_peft_version,
|
34
34
|
logging,
|
35
35
|
set_adapter_layers,
|
36
36
|
set_weights_and_activate_adapters,
|
37
37
|
)
|
38
|
+
from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys
|
38
39
|
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
|
39
40
|
from .unet_loader_utils import _maybe_expand_lora_scales
|
40
41
|
|
@@ -52,32 +53,18 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
|
52
53
|
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
|
53
54
|
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
|
54
55
|
"SanaTransformer2DModel": lambda model_cls, weights: weights,
|
56
|
+
"AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
|
55
57
|
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
|
56
58
|
"WanTransformer3DModel": lambda model_cls, weights: weights,
|
57
59
|
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
|
60
|
+
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
|
61
|
+
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
|
62
|
+
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
|
63
|
+
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
|
64
|
+
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
|
58
65
|
}
|
59
66
|
|
60
67
|
|
61
|
-
def _maybe_raise_error_for_ambiguity(config):
|
62
|
-
rank_pattern = config["rank_pattern"].copy()
|
63
|
-
target_modules = config["target_modules"]
|
64
|
-
|
65
|
-
for key in list(rank_pattern.keys()):
|
66
|
-
# try to detect ambiguity
|
67
|
-
# `target_modules` can also be a str, in which case this loop would loop
|
68
|
-
# over the chars of the str. The technically correct way to match LoRA keys
|
69
|
-
# in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
|
70
|
-
# But this cuts it for now.
|
71
|
-
exact_matches = [mod for mod in target_modules if mod == key]
|
72
|
-
substring_matches = [mod for mod in target_modules if key in mod and mod != key]
|
73
|
-
|
74
|
-
if exact_matches and substring_matches:
|
75
|
-
if is_peft_version("<", "0.14.1"):
|
76
|
-
raise ValueError(
|
77
|
-
"There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
|
78
|
-
)
|
79
|
-
|
80
|
-
|
81
68
|
class PeftAdapterMixin:
|
82
69
|
"""
|
83
70
|
A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
|
@@ -99,17 +86,6 @@ class PeftAdapterMixin:
|
|
99
86
|
@classmethod
|
100
87
|
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
|
101
88
|
def _optionally_disable_offloading(cls, _pipeline):
|
102
|
-
"""
|
103
|
-
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
104
|
-
|
105
|
-
Args:
|
106
|
-
_pipeline (`DiffusionPipeline`):
|
107
|
-
The pipeline to disable offloading for.
|
108
|
-
|
109
|
-
Returns:
|
110
|
-
tuple:
|
111
|
-
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
112
|
-
"""
|
113
89
|
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
114
90
|
|
115
91
|
def load_lora_adapter(
|
@@ -181,10 +157,15 @@ class PeftAdapterMixin:
|
|
181
157
|
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
182
158
|
limitations to this technique, which are documented here:
|
183
159
|
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
160
|
+
metadata:
|
161
|
+
LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
|
162
|
+
initialize `LoraConfig`.
|
184
163
|
"""
|
185
|
-
from peft import
|
164
|
+
from peft import inject_adapter_in_model, set_peft_model_state_dict
|
186
165
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
187
166
|
|
167
|
+
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
|
168
|
+
|
188
169
|
cache_dir = kwargs.pop("cache_dir", None)
|
189
170
|
force_download = kwargs.pop("force_download", False)
|
190
171
|
proxies = kwargs.pop("proxies", None)
|
@@ -198,6 +179,7 @@ class PeftAdapterMixin:
|
|
198
179
|
network_alphas = kwargs.pop("network_alphas", None)
|
199
180
|
_pipeline = kwargs.pop("_pipeline", None)
|
200
181
|
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
|
182
|
+
metadata = kwargs.pop("metadata", None)
|
201
183
|
allow_pickle = False
|
202
184
|
|
203
185
|
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
|
@@ -205,12 +187,8 @@ class PeftAdapterMixin:
|
|
205
187
|
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
206
188
|
)
|
207
189
|
|
208
|
-
user_agent = {
|
209
|
-
|
210
|
-
"framework": "pytorch",
|
211
|
-
}
|
212
|
-
|
213
|
-
state_dict = _fetch_state_dict(
|
190
|
+
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
191
|
+
state_dict, metadata = _fetch_state_dict(
|
214
192
|
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
215
193
|
weight_name=weight_name,
|
216
194
|
use_safetensors=use_safetensors,
|
@@ -223,12 +201,17 @@ class PeftAdapterMixin:
|
|
223
201
|
subfolder=subfolder,
|
224
202
|
user_agent=user_agent,
|
225
203
|
allow_pickle=allow_pickle,
|
204
|
+
metadata=metadata,
|
226
205
|
)
|
227
206
|
if network_alphas is not None and prefix is None:
|
228
207
|
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
|
208
|
+
if network_alphas and metadata:
|
209
|
+
raise ValueError("Both `network_alphas` and `metadata` cannot be specified.")
|
229
210
|
|
230
211
|
if prefix is not None:
|
231
|
-
state_dict = {k
|
212
|
+
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
213
|
+
if metadata is not None:
|
214
|
+
metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
|
232
215
|
|
233
216
|
if len(state_dict) > 0:
|
234
217
|
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
|
@@ -248,7 +231,7 @@ class PeftAdapterMixin:
|
|
248
231
|
|
249
232
|
rank = {}
|
250
233
|
for key, val in state_dict.items():
|
251
|
-
# Cannot figure out rank from lora layers that don't have
|
234
|
+
# Cannot figure out rank from lora layers that don't have at least 2 dimensions.
|
252
235
|
# Bias layers in LoRA only have a single dimension
|
253
236
|
if "lora_B" in key and val.ndim > 1:
|
254
237
|
# Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
|
@@ -259,44 +242,33 @@ class PeftAdapterMixin:
|
|
259
242
|
|
260
243
|
if network_alphas is not None and len(network_alphas) >= 1:
|
261
244
|
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
|
262
|
-
network_alphas = {
|
263
|
-
|
264
|
-
|
265
|
-
_maybe_raise_error_for_ambiguity(lora_config_kwargs)
|
245
|
+
network_alphas = {
|
246
|
+
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
|
247
|
+
}
|
266
248
|
|
267
|
-
if "use_dora" in lora_config_kwargs:
|
268
|
-
if lora_config_kwargs["use_dora"]:
|
269
|
-
if is_peft_version("<", "0.9.0"):
|
270
|
-
raise ValueError(
|
271
|
-
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
272
|
-
)
|
273
|
-
else:
|
274
|
-
if is_peft_version("<", "0.9.0"):
|
275
|
-
lora_config_kwargs.pop("use_dora")
|
276
|
-
|
277
|
-
if "lora_bias" in lora_config_kwargs:
|
278
|
-
if lora_config_kwargs["lora_bias"]:
|
279
|
-
if is_peft_version("<=", "0.13.2"):
|
280
|
-
raise ValueError(
|
281
|
-
"You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
|
282
|
-
)
|
283
|
-
else:
|
284
|
-
if is_peft_version("<=", "0.13.2"):
|
285
|
-
lora_config_kwargs.pop("lora_bias")
|
286
|
-
|
287
|
-
lora_config = LoraConfig(**lora_config_kwargs)
|
288
249
|
# adapter_name
|
289
250
|
if adapter_name is None:
|
290
251
|
adapter_name = get_adapter_name(self)
|
291
252
|
|
253
|
+
# create LoraConfig
|
254
|
+
lora_config = _create_lora_config(
|
255
|
+
state_dict,
|
256
|
+
network_alphas,
|
257
|
+
metadata,
|
258
|
+
rank,
|
259
|
+
model_state_dict=self.state_dict(),
|
260
|
+
adapter_name=adapter_name,
|
261
|
+
)
|
262
|
+
|
292
263
|
# <Unsafe code
|
293
264
|
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
|
294
265
|
# Now we remove any existing hooks to `_pipeline`.
|
295
266
|
|
296
267
|
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
297
|
-
# otherwise loading LoRA weights will lead to an error
|
298
|
-
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(
|
299
|
-
|
268
|
+
# otherwise loading LoRA weights will lead to an error.
|
269
|
+
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
|
270
|
+
_pipeline
|
271
|
+
)
|
300
272
|
peft_kwargs = {}
|
301
273
|
if is_peft_version(">=", "0.13.1"):
|
302
274
|
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
@@ -328,7 +300,7 @@ class PeftAdapterMixin:
|
|
328
300
|
new_sd[k] = v
|
329
301
|
return new_sd
|
330
302
|
|
331
|
-
# To handle scenarios where we cannot successfully set state dict. If it's
|
303
|
+
# To handle scenarios where we cannot successfully set state dict. If it's unsuccessful,
|
332
304
|
# we should also delete the `peft_config` associated to the `adapter_name`.
|
333
305
|
try:
|
334
306
|
if hotswap:
|
@@ -342,13 +314,15 @@ class PeftAdapterMixin:
|
|
342
314
|
config=lora_config,
|
343
315
|
)
|
344
316
|
except Exception as e:
|
345
|
-
logger.error(f"Hotswapping {adapter_name} was
|
317
|
+
logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
|
346
318
|
raise
|
347
319
|
# the hotswap function raises if there are incompatible keys, so if we reach this point we can set
|
348
320
|
# it to None
|
349
321
|
incompatible_keys = None
|
350
322
|
else:
|
351
|
-
inject_adapter_in_model(
|
323
|
+
inject_adapter_in_model(
|
324
|
+
lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs
|
325
|
+
)
|
352
326
|
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
|
353
327
|
|
354
328
|
if self._prepare_lora_hotswap_kwargs is not None:
|
@@ -377,46 +351,28 @@ class PeftAdapterMixin:
|
|
377
351
|
module.delete_adapter(adapter_name)
|
378
352
|
|
379
353
|
self.peft_config.pop(adapter_name)
|
380
|
-
logger.error(f"Loading {adapter_name} was
|
354
|
+
logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
|
381
355
|
raise
|
382
356
|
|
383
|
-
|
384
|
-
if incompatible_keys is not None:
|
385
|
-
# Check only for unexpected keys.
|
386
|
-
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
387
|
-
if unexpected_keys:
|
388
|
-
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
|
389
|
-
if lora_unexpected_keys:
|
390
|
-
warn_msg = (
|
391
|
-
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
|
392
|
-
f" {', '.join(lora_unexpected_keys)}. "
|
393
|
-
)
|
394
|
-
|
395
|
-
# Filter missing keys specific to the current adapter.
|
396
|
-
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
397
|
-
if missing_keys:
|
398
|
-
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
|
399
|
-
if lora_missing_keys:
|
400
|
-
warn_msg += (
|
401
|
-
f"Loading adapter weights from state_dict led to missing keys in the model:"
|
402
|
-
f" {', '.join(lora_missing_keys)}."
|
403
|
-
)
|
404
|
-
|
405
|
-
if warn_msg:
|
406
|
-
logger.warning(warn_msg)
|
357
|
+
_maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)
|
407
358
|
|
408
359
|
# Offload back.
|
409
360
|
if is_model_cpu_offload:
|
410
361
|
_pipeline.enable_model_cpu_offload()
|
411
362
|
elif is_sequential_cpu_offload:
|
412
363
|
_pipeline.enable_sequential_cpu_offload()
|
364
|
+
elif is_group_offload:
|
365
|
+
for component in _pipeline.components.values():
|
366
|
+
if isinstance(component, torch.nn.Module):
|
367
|
+
_maybe_remove_and_reapply_group_offloading(component)
|
413
368
|
# Unsafe code />
|
414
369
|
|
415
370
|
if prefix is not None and not state_dict:
|
371
|
+
model_class_name = self.__class__.__name__
|
416
372
|
logger.warning(
|
417
|
-
f"No LoRA keys associated to {
|
373
|
+
f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
|
418
374
|
"This is safe to ignore if LoRA state dict didn't originally have any "
|
419
|
-
f"{
|
375
|
+
f"{model_class_name} related params. You can also try specifying `prefix=None` "
|
420
376
|
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
|
421
377
|
"https://github.com/huggingface/diffusers/issues/new"
|
422
378
|
)
|
@@ -439,17 +395,13 @@ class PeftAdapterMixin:
|
|
439
395
|
underlying model has multiple adapters loaded.
|
440
396
|
upcast_before_saving (`bool`, defaults to `False`):
|
441
397
|
Whether to cast the underlying model to `torch.float32` before serialization.
|
442
|
-
save_function (`Callable`):
|
443
|
-
The function to use to save the state dictionary. Useful during distributed training when you need to
|
444
|
-
replace `torch.save` with another method. Can be configured with the environment variable
|
445
|
-
`DIFFUSERS_SAVE_MODE`.
|
446
398
|
safe_serialization (`bool`, *optional*, defaults to `True`):
|
447
399
|
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
448
400
|
weight_name (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
|
449
401
|
"""
|
450
402
|
from peft.utils import get_peft_model_state_dict
|
451
403
|
|
452
|
-
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
|
404
|
+
from .lora_base import LORA_ADAPTER_METADATA_KEY, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
|
453
405
|
|
454
406
|
if adapter_name is None:
|
455
407
|
adapter_name = get_adapter_name(self)
|
@@ -457,6 +409,8 @@ class PeftAdapterMixin:
|
|
457
409
|
if adapter_name not in getattr(self, "peft_config", {}):
|
458
410
|
raise ValueError(f"Adapter name {adapter_name} not found in the model.")
|
459
411
|
|
412
|
+
lora_adapter_metadata = self.peft_config[adapter_name].to_dict()
|
413
|
+
|
460
414
|
lora_layers_to_save = get_peft_model_state_dict(
|
461
415
|
self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
|
462
416
|
)
|
@@ -466,7 +420,15 @@ class PeftAdapterMixin:
|
|
466
420
|
if safe_serialization:
|
467
421
|
|
468
422
|
def save_function(weights, filename):
|
469
|
-
|
423
|
+
# Inject framework format.
|
424
|
+
metadata = {"format": "pt"}
|
425
|
+
if lora_adapter_metadata is not None:
|
426
|
+
for key, value in lora_adapter_metadata.items():
|
427
|
+
if isinstance(value, set):
|
428
|
+
lora_adapter_metadata[key] = list(value)
|
429
|
+
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
|
430
|
+
|
431
|
+
return safetensors.torch.save_file(weights, filename, metadata=metadata)
|
470
432
|
|
471
433
|
else:
|
472
434
|
save_function = torch.save
|
@@ -479,7 +441,6 @@ class PeftAdapterMixin:
|
|
479
441
|
else:
|
480
442
|
weight_name = LORA_WEIGHT_NAME
|
481
443
|
|
482
|
-
# TODO: we could consider saving the `peft_config` as well.
|
483
444
|
save_path = Path(save_directory, weight_name).as_posix()
|
484
445
|
save_function(lora_layers_to_save, save_path)
|
485
446
|
logger.info(f"Model weights saved in {save_path}")
|
@@ -490,7 +451,7 @@ class PeftAdapterMixin:
|
|
490
451
|
weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
|
491
452
|
):
|
492
453
|
"""
|
493
|
-
Set the currently active adapters for use in the
|
454
|
+
Set the currently active adapters for use in the diffusion network (e.g. unet, transformer, etc.).
|
494
455
|
|
495
456
|
Args:
|
496
457
|
adapter_names (`List[str]` or `str`):
|
@@ -512,7 +473,7 @@ class PeftAdapterMixin:
|
|
512
473
|
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
|
513
474
|
)
|
514
475
|
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
515
|
-
pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
|
476
|
+
pipeline.unet.set_adapters(["cinematic", "pixel"], weights=[0.5, 0.5])
|
516
477
|
```
|
517
478
|
"""
|
518
479
|
if not USE_PEFT_BACKEND:
|
@@ -710,7 +671,7 @@ class PeftAdapterMixin:
|
|
710
671
|
if self.lora_scale != 1.0:
|
711
672
|
module.scale_layer(self.lora_scale)
|
712
673
|
|
713
|
-
# For BC with
|
674
|
+
# For BC with previous PEFT versions, we need to check the signature
|
714
675
|
# of the `merge` method to see if it supports the `adapter_names` argument.
|
715
676
|
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
716
677
|
if "adapter_names" in supported_merge_kwargs:
|
@@ -738,11 +699,16 @@ class PeftAdapterMixin:
|
|
738
699
|
if not USE_PEFT_BACKEND:
|
739
700
|
raise ValueError("PEFT backend is required for `unload_lora()`.")
|
740
701
|
|
702
|
+
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
|
741
703
|
from ..utils import recurse_remove_peft_layers
|
742
704
|
|
743
705
|
recurse_remove_peft_layers(self)
|
744
706
|
if hasattr(self, "peft_config"):
|
745
707
|
del self.peft_config
|
708
|
+
if hasattr(self, "_hf_peft_config_loaded"):
|
709
|
+
self._hf_peft_config_loaded = None
|
710
|
+
|
711
|
+
_maybe_remove_and_reapply_group_offloading(self)
|
746
712
|
|
747
713
|
def disable_lora(self):
|
748
714
|
"""
|
@@ -760,7 +726,7 @@ class PeftAdapterMixin:
|
|
760
726
|
pipeline.load_lora_weights(
|
761
727
|
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
|
762
728
|
)
|
763
|
-
pipeline.disable_lora()
|
729
|
+
pipeline.unet.disable_lora()
|
764
730
|
```
|
765
731
|
"""
|
766
732
|
if not USE_PEFT_BACKEND:
|
@@ -783,7 +749,7 @@ class PeftAdapterMixin:
|
|
783
749
|
pipeline.load_lora_weights(
|
784
750
|
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
|
785
751
|
)
|
786
|
-
pipeline.enable_lora()
|
752
|
+
pipeline.unet.enable_lora()
|
787
753
|
```
|
788
754
|
"""
|
789
755
|
if not USE_PEFT_BACKEND:
|
@@ -810,7 +776,7 @@ class PeftAdapterMixin:
|
|
810
776
|
pipeline.load_lora_weights(
|
811
777
|
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
|
812
778
|
)
|
813
|
-
pipeline.delete_adapters("cinematic")
|
779
|
+
pipeline.unet.delete_adapters("cinematic")
|
814
780
|
```
|
815
781
|
"""
|
816
782
|
if not USE_PEFT_BACKEND:
|
diffusers/loaders/single_file.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -453,7 +453,7 @@ class FromSingleFileMixin:
|
|
453
453
|
logger.warning(
|
454
454
|
"Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
|
455
455
|
"This may lead to errors if the model components are not correctly inferred. \n"
|
456
|
-
"To avoid this warning, please
|
456
|
+
"To avoid this warning, please explicitly pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
|
457
457
|
"e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>) \n"
|
458
458
|
"or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
|
459
459
|
"the necessary config files.\n"
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -21,15 +21,20 @@ import torch
|
|
21
21
|
from huggingface_hub.utils import validate_hf_hub_args
|
22
22
|
from typing_extensions import Self
|
23
23
|
|
24
|
+
from .. import __version__
|
24
25
|
from ..quantizers import DiffusersAutoQuantizer
|
25
|
-
from ..utils import deprecate, is_accelerate_available, logging
|
26
|
+
from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
|
27
|
+
from ..utils.torch_utils import empty_device_cache
|
26
28
|
from .single_file_utils import (
|
27
29
|
SingleFileComponentError,
|
28
30
|
convert_animatediff_checkpoint_to_diffusers,
|
29
31
|
convert_auraflow_transformer_checkpoint_to_diffusers,
|
30
32
|
convert_autoencoder_dc_checkpoint_to_diffusers,
|
33
|
+
convert_chroma_transformer_checkpoint_to_diffusers,
|
31
34
|
convert_controlnet_checkpoint,
|
35
|
+
convert_cosmos_transformer_checkpoint_to_diffusers,
|
32
36
|
convert_flux_transformer_checkpoint_to_diffusers,
|
37
|
+
convert_hidream_transformer_to_diffusers,
|
33
38
|
convert_hunyuan_video_transformer_to_diffusers,
|
34
39
|
convert_ldm_unet_checkpoint,
|
35
40
|
convert_ldm_vae_checkpoint,
|
@@ -57,8 +62,12 @@ logger = logging.get_logger(__name__)
|
|
57
62
|
if is_accelerate_available():
|
58
63
|
from accelerate import dispatch_model, init_empty_weights
|
59
64
|
|
60
|
-
from ..models.
|
65
|
+
from ..models.model_loading_utils import load_model_dict_into_meta
|
61
66
|
|
67
|
+
if is_torch_version(">=", "1.9.0") and is_accelerate_available():
|
68
|
+
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
69
|
+
else:
|
70
|
+
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
62
71
|
|
63
72
|
SINGLE_FILE_LOADABLE_CLASSES = {
|
64
73
|
"StableCascadeUNet": {
|
@@ -95,6 +104,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
|
95
104
|
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
|
96
105
|
"default_subfolder": "transformer",
|
97
106
|
},
|
107
|
+
"ChromaTransformer2DModel": {
|
108
|
+
"checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
|
109
|
+
"default_subfolder": "transformer",
|
110
|
+
},
|
98
111
|
"LTXVideoTransformer3DModel": {
|
99
112
|
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
|
100
113
|
"default_subfolder": "transformer",
|
@@ -128,13 +141,33 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
|
128
141
|
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
129
142
|
"default_subfolder": "transformer",
|
130
143
|
},
|
144
|
+
"WanVACETransformer3DModel": {
|
145
|
+
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
146
|
+
"default_subfolder": "transformer",
|
147
|
+
},
|
131
148
|
"AutoencoderKLWan": {
|
132
149
|
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
|
133
150
|
"default_subfolder": "vae",
|
134
151
|
},
|
152
|
+
"HiDreamImageTransformer2DModel": {
|
153
|
+
"checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
|
154
|
+
"default_subfolder": "transformer",
|
155
|
+
},
|
156
|
+
"CosmosTransformer3DModel": {
|
157
|
+
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
|
158
|
+
"default_subfolder": "transformer",
|
159
|
+
},
|
160
|
+
"QwenImageTransformer2DModel": {
|
161
|
+
"checkpoint_mapping_fn": lambda x: x,
|
162
|
+
"default_subfolder": "transformer",
|
163
|
+
},
|
135
164
|
}
|
136
165
|
|
137
166
|
|
167
|
+
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
|
168
|
+
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
|
169
|
+
|
170
|
+
|
138
171
|
def _get_single_file_loadable_mapping_class(cls):
|
139
172
|
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
140
173
|
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
|
@@ -186,9 +219,8 @@ class FromOriginalModelMixin:
|
|
186
219
|
original_config (`str`, *optional*):
|
187
220
|
Dict or path to a yaml file containing the configuration for the model in its original format.
|
188
221
|
If a dict is provided, it will be used to initialize the model configuration.
|
189
|
-
torch_dtype (`
|
190
|
-
Override the default `torch.dtype` and load the model with another dtype.
|
191
|
-
dtype is automatically derived from the model's weights.
|
222
|
+
torch_dtype (`torch.dtype`, *optional*):
|
223
|
+
Override the default `torch.dtype` and load the model with another dtype.
|
192
224
|
force_download (`bool`, *optional*, defaults to `False`):
|
193
225
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
194
226
|
cached versions if they exist.
|
@@ -208,6 +240,11 @@ class FromOriginalModelMixin:
|
|
208
240
|
revision (`str`, *optional*, defaults to `"main"`):
|
209
241
|
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
210
242
|
allowed by Git.
|
243
|
+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 and
|
244
|
+
is_accelerate_available() else `False`): Speed up model loading by only loading the pretrained weights and
|
245
|
+
not initializing the weights. This also tries to not use more than 1x model size in CPU memory
|
246
|
+
(including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using
|
247
|
+
an older version of PyTorch, setting this argument to `True` will raise an error.
|
211
248
|
disable_mmap (`bool`, *optional*, defaults to `False`):
|
212
249
|
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
213
250
|
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
@@ -257,9 +294,15 @@ class FromOriginalModelMixin:
|
|
257
294
|
config_revision = kwargs.pop("config_revision", None)
|
258
295
|
torch_dtype = kwargs.pop("torch_dtype", None)
|
259
296
|
quantization_config = kwargs.pop("quantization_config", None)
|
297
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
260
298
|
device = kwargs.pop("device", None)
|
261
299
|
disable_mmap = kwargs.pop("disable_mmap", False)
|
262
300
|
|
301
|
+
user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
|
302
|
+
# In order to ensure popular quantization methods are supported. Can be disabled with `disable_telemetry`
|
303
|
+
if quantization_config is not None:
|
304
|
+
user_agent["quant"] = quantization_config.quant_method.value
|
305
|
+
|
263
306
|
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
264
307
|
torch_dtype = torch.float32
|
265
308
|
logger.warning(
|
@@ -278,6 +321,7 @@ class FromOriginalModelMixin:
|
|
278
321
|
local_files_only=local_files_only,
|
279
322
|
revision=revision,
|
280
323
|
disable_mmap=disable_mmap,
|
324
|
+
user_agent=user_agent,
|
281
325
|
)
|
282
326
|
if quantization_config is not None:
|
283
327
|
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
@@ -355,19 +399,23 @@ class FromOriginalModelMixin:
|
|
355
399
|
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
|
356
400
|
diffusers_model_config.update(model_kwargs)
|
357
401
|
|
402
|
+
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
|
403
|
+
with ctx():
|
404
|
+
model = cls.from_config(diffusers_model_config)
|
405
|
+
|
358
406
|
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
359
|
-
|
360
|
-
|
361
|
-
|
407
|
+
|
408
|
+
if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
|
409
|
+
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
410
|
+
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
411
|
+
)
|
412
|
+
else:
|
413
|
+
diffusers_format_checkpoint = checkpoint
|
414
|
+
|
362
415
|
if not diffusers_format_checkpoint:
|
363
416
|
raise SingleFileComponentError(
|
364
417
|
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
365
418
|
)
|
366
|
-
|
367
|
-
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
368
|
-
with ctx():
|
369
|
-
model = cls.from_config(diffusers_model_config)
|
370
|
-
|
371
419
|
# Check if `_keep_in_fp32_modules` is not None
|
372
420
|
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
373
421
|
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
@@ -389,7 +437,7 @@ class FromOriginalModelMixin:
|
|
389
437
|
)
|
390
438
|
|
391
439
|
device_map = None
|
392
|
-
if
|
440
|
+
if low_cpu_mem_usage:
|
393
441
|
param_device = torch.device(device) if device else torch.device("cpu")
|
394
442
|
empty_state_dict = model.state_dict()
|
395
443
|
unexpected_keys = [
|
@@ -405,6 +453,7 @@ class FromOriginalModelMixin:
|
|
405
453
|
keep_in_fp32_modules=keep_in_fp32_modules,
|
406
454
|
unexpected_keys=unexpected_keys,
|
407
455
|
)
|
456
|
+
empty_device_cache()
|
408
457
|
else:
|
409
458
|
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
410
459
|
|