diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -13,15 +13,22 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import re
|
16
|
+
from typing import List
|
16
17
|
|
17
18
|
import torch
|
18
19
|
|
19
|
-
from ..utils import is_peft_version, logging
|
20
|
+
from ..utils import is_peft_version, logging, state_dict_all_zero
|
20
21
|
|
21
22
|
|
22
23
|
logger = logging.get_logger(__name__)
|
23
24
|
|
24
25
|
|
26
|
+
def swap_scale_shift(weight):
|
27
|
+
shift, scale = weight.chunk(2, dim=0)
|
28
|
+
new_weight = torch.cat([scale, shift], dim=0)
|
29
|
+
return new_weight
|
30
|
+
|
31
|
+
|
25
32
|
def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
|
26
33
|
# 1. get all state_dict_keys
|
27
34
|
all_keys = list(state_dict.keys())
|
@@ -177,9 +184,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
|
|
177
184
|
# Store DoRA scale if present.
|
178
185
|
if dora_present_in_unet:
|
179
186
|
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
|
180
|
-
unet_state_dict[
|
181
|
-
|
182
|
-
|
187
|
+
unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = (
|
188
|
+
state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
189
|
+
)
|
183
190
|
|
184
191
|
# Handle text encoder LoRAs.
|
185
192
|
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
|
@@ -199,13 +206,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
|
|
199
206
|
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
|
200
207
|
)
|
201
208
|
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
202
|
-
te_state_dict[
|
203
|
-
|
204
|
-
|
209
|
+
te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
|
210
|
+
state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
211
|
+
)
|
205
212
|
elif lora_name.startswith("lora_te2_"):
|
206
|
-
te2_state_dict[
|
207
|
-
|
208
|
-
|
213
|
+
te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
|
214
|
+
state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
215
|
+
)
|
209
216
|
|
210
217
|
# Store alpha if present.
|
211
218
|
if lora_name_alpha in state_dict:
|
@@ -313,6 +320,7 @@ def _convert_text_encoder_lora_key(key, lora_name):
|
|
313
320
|
# Be aware that this is the new diffusers convention and the rest of the code might
|
314
321
|
# not utilize it yet.
|
315
322
|
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
|
323
|
+
|
316
324
|
return diffusers_name
|
317
325
|
|
318
326
|
|
@@ -331,8 +339,7 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
|
|
331
339
|
|
332
340
|
|
333
341
|
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
|
334
|
-
# are
|
335
|
-
# All credits go to `kohya-ss`.
|
342
|
+
# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
|
336
343
|
def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
337
344
|
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
|
338
345
|
if sds_key + ".lora_down.weight" not in sds_sd:
|
@@ -341,7 +348,8 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
|
341
348
|
|
342
349
|
# scale weight by alpha and dim
|
343
350
|
rank = down_weight.shape[0]
|
344
|
-
|
351
|
+
default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
|
352
|
+
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item() # alpha is scalar
|
345
353
|
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
346
354
|
|
347
355
|
# calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
|
@@ -362,7 +370,10 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
|
362
370
|
sd_lora_rank = down_weight.shape[0]
|
363
371
|
|
364
372
|
# scale weight by alpha and dim
|
365
|
-
|
373
|
+
default_alpha = torch.tensor(
|
374
|
+
sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
|
375
|
+
)
|
376
|
+
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
|
366
377
|
scale = alpha / sd_lora_rank
|
367
378
|
|
368
379
|
# calculate scale_down and scale_up
|
@@ -516,10 +527,103 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
|
516
527
|
f"transformer.single_transformer_blocks.{i}.norm.linear",
|
517
528
|
)
|
518
529
|
|
530
|
+
# TODO: alphas.
|
531
|
+
def assign_remaining_weights(assignments, source):
|
532
|
+
for lora_key in ["lora_A", "lora_B"]:
|
533
|
+
orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
|
534
|
+
for target_fmt, source_fmt, transform in assignments:
|
535
|
+
target_key = target_fmt.format(lora_key=lora_key)
|
536
|
+
source_key = source_fmt.format(orig_lora_key=orig_lora_key)
|
537
|
+
value = source.pop(source_key)
|
538
|
+
if transform:
|
539
|
+
value = transform(value)
|
540
|
+
ait_sd[target_key] = value
|
541
|
+
|
542
|
+
if any("guidance_in" in k for k in sds_sd):
|
543
|
+
assign_remaining_weights(
|
544
|
+
[
|
545
|
+
(
|
546
|
+
"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
|
547
|
+
"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
|
548
|
+
None,
|
549
|
+
),
|
550
|
+
(
|
551
|
+
"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
|
552
|
+
"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
|
553
|
+
None,
|
554
|
+
),
|
555
|
+
],
|
556
|
+
sds_sd,
|
557
|
+
)
|
558
|
+
|
559
|
+
if any("img_in" in k for k in sds_sd):
|
560
|
+
assign_remaining_weights(
|
561
|
+
[
|
562
|
+
("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
|
563
|
+
],
|
564
|
+
sds_sd,
|
565
|
+
)
|
566
|
+
|
567
|
+
if any("txt_in" in k for k in sds_sd):
|
568
|
+
assign_remaining_weights(
|
569
|
+
[
|
570
|
+
("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
|
571
|
+
],
|
572
|
+
sds_sd,
|
573
|
+
)
|
574
|
+
|
575
|
+
if any("time_in" in k for k in sds_sd):
|
576
|
+
assign_remaining_weights(
|
577
|
+
[
|
578
|
+
(
|
579
|
+
"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
|
580
|
+
"lora_unet_time_in_in_layer.{orig_lora_key}.weight",
|
581
|
+
None,
|
582
|
+
),
|
583
|
+
(
|
584
|
+
"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
|
585
|
+
"lora_unet_time_in_out_layer.{orig_lora_key}.weight",
|
586
|
+
None,
|
587
|
+
),
|
588
|
+
],
|
589
|
+
sds_sd,
|
590
|
+
)
|
591
|
+
|
592
|
+
if any("vector_in" in k for k in sds_sd):
|
593
|
+
assign_remaining_weights(
|
594
|
+
[
|
595
|
+
(
|
596
|
+
"time_text_embed.text_embedder.linear_1.{lora_key}.weight",
|
597
|
+
"lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
|
598
|
+
None,
|
599
|
+
),
|
600
|
+
(
|
601
|
+
"time_text_embed.text_embedder.linear_2.{lora_key}.weight",
|
602
|
+
"lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
|
603
|
+
None,
|
604
|
+
),
|
605
|
+
],
|
606
|
+
sds_sd,
|
607
|
+
)
|
608
|
+
|
609
|
+
if any("final_layer" in k for k in sds_sd):
|
610
|
+
# Notice the swap in processing for "final_layer".
|
611
|
+
assign_remaining_weights(
|
612
|
+
[
|
613
|
+
(
|
614
|
+
"norm_out.linear.{lora_key}.weight",
|
615
|
+
"lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight",
|
616
|
+
swap_scale_shift,
|
617
|
+
),
|
618
|
+
("proj_out.{lora_key}.weight", "lora_unet_final_layer_linear.{orig_lora_key}.weight", None),
|
619
|
+
],
|
620
|
+
sds_sd,
|
621
|
+
)
|
622
|
+
|
519
623
|
remaining_keys = list(sds_sd.keys())
|
520
624
|
te_state_dict = {}
|
521
625
|
if remaining_keys:
|
522
|
-
if not all(k.startswith("lora_te1") for k in remaining_keys):
|
626
|
+
if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys):
|
523
627
|
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
|
524
628
|
for key in remaining_keys:
|
525
629
|
if not key.endswith("lora_down.weight"):
|
@@ -558,6 +662,223 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
|
558
662
|
new_state_dict = {**ait_sd, **te_state_dict}
|
559
663
|
return new_state_dict
|
560
664
|
|
665
|
+
def _convert_mixture_state_dict_to_diffusers(state_dict):
|
666
|
+
new_state_dict = {}
|
667
|
+
|
668
|
+
def _convert(original_key, diffusers_key, state_dict, new_state_dict):
|
669
|
+
down_key = f"{original_key}.lora_down.weight"
|
670
|
+
down_weight = state_dict.pop(down_key)
|
671
|
+
lora_rank = down_weight.shape[0]
|
672
|
+
|
673
|
+
up_weight_key = f"{original_key}.lora_up.weight"
|
674
|
+
up_weight = state_dict.pop(up_weight_key)
|
675
|
+
|
676
|
+
alpha_key = f"{original_key}.alpha"
|
677
|
+
alpha = state_dict.pop(alpha_key)
|
678
|
+
|
679
|
+
# scale weight by alpha and dim
|
680
|
+
scale = alpha / lora_rank
|
681
|
+
# calculate scale_down and scale_up
|
682
|
+
scale_down = scale
|
683
|
+
scale_up = 1.0
|
684
|
+
while scale_down * 2 < scale_up:
|
685
|
+
scale_down *= 2
|
686
|
+
scale_up /= 2
|
687
|
+
down_weight = down_weight * scale_down
|
688
|
+
up_weight = up_weight * scale_up
|
689
|
+
|
690
|
+
diffusers_down_key = f"{diffusers_key}.lora_A.weight"
|
691
|
+
new_state_dict[diffusers_down_key] = down_weight
|
692
|
+
new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
|
693
|
+
|
694
|
+
all_unique_keys = {
|
695
|
+
k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
|
696
|
+
for k in state_dict
|
697
|
+
if not k.startswith(("lora_unet_"))
|
698
|
+
}
|
699
|
+
assert all(k.startswith(("lora_transformer_", "lora_te1_")) for k in all_unique_keys), f"{all_unique_keys=}"
|
700
|
+
|
701
|
+
has_te_keys = False
|
702
|
+
for k in all_unique_keys:
|
703
|
+
if k.startswith("lora_transformer_single_transformer_blocks_"):
|
704
|
+
i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
|
705
|
+
diffusers_key = f"single_transformer_blocks.{i}"
|
706
|
+
elif k.startswith("lora_transformer_transformer_blocks_"):
|
707
|
+
i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
|
708
|
+
diffusers_key = f"transformer_blocks.{i}"
|
709
|
+
elif k.startswith("lora_te1_"):
|
710
|
+
has_te_keys = True
|
711
|
+
continue
|
712
|
+
else:
|
713
|
+
raise NotImplementedError
|
714
|
+
|
715
|
+
if "attn_" in k:
|
716
|
+
if "_to_out_0" in k:
|
717
|
+
diffusers_key += ".attn.to_out.0"
|
718
|
+
elif "_to_add_out" in k:
|
719
|
+
diffusers_key += ".attn.to_add_out"
|
720
|
+
elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
|
721
|
+
remaining = k.split("attn_")[-1]
|
722
|
+
diffusers_key += f".attn.{remaining}"
|
723
|
+
elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
|
724
|
+
remaining = k.split("attn_")[-1]
|
725
|
+
diffusers_key += f".attn.{remaining}"
|
726
|
+
|
727
|
+
_convert(k, diffusers_key, state_dict, new_state_dict)
|
728
|
+
|
729
|
+
if has_te_keys:
|
730
|
+
layer_pattern = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
|
731
|
+
attn_mapping = {
|
732
|
+
"q_proj": ".self_attn.q_proj",
|
733
|
+
"k_proj": ".self_attn.k_proj",
|
734
|
+
"v_proj": ".self_attn.v_proj",
|
735
|
+
"out_proj": ".self_attn.out_proj",
|
736
|
+
}
|
737
|
+
mlp_mapping = {"fc1": ".mlp.fc1", "fc2": ".mlp.fc2"}
|
738
|
+
for k in all_unique_keys:
|
739
|
+
if not k.startswith("lora_te1_"):
|
740
|
+
continue
|
741
|
+
|
742
|
+
match = layer_pattern.search(k)
|
743
|
+
if not match:
|
744
|
+
continue
|
745
|
+
i = int(match.group(1))
|
746
|
+
diffusers_key = f"text_model.encoder.layers.{i}"
|
747
|
+
|
748
|
+
if "attn" in k:
|
749
|
+
for key_fragment, suffix in attn_mapping.items():
|
750
|
+
if key_fragment in k:
|
751
|
+
diffusers_key += suffix
|
752
|
+
break
|
753
|
+
elif "mlp" in k:
|
754
|
+
for key_fragment, suffix in mlp_mapping.items():
|
755
|
+
if key_fragment in k:
|
756
|
+
diffusers_key += suffix
|
757
|
+
break
|
758
|
+
|
759
|
+
_convert(k, diffusers_key, state_dict, new_state_dict)
|
760
|
+
|
761
|
+
remaining_all_unet = False
|
762
|
+
if state_dict:
|
763
|
+
remaining_all_unet = all(k.startswith("lora_unet_") for k in state_dict)
|
764
|
+
if remaining_all_unet:
|
765
|
+
keys = list(state_dict.keys())
|
766
|
+
for k in keys:
|
767
|
+
state_dict.pop(k)
|
768
|
+
|
769
|
+
if len(state_dict) > 0:
|
770
|
+
raise ValueError(
|
771
|
+
f"Expected an empty state dict at this point but its has these keys which couldn't be parsed: {list(state_dict.keys())}."
|
772
|
+
)
|
773
|
+
|
774
|
+
transformer_state_dict = {
|
775
|
+
f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
|
776
|
+
}
|
777
|
+
te_state_dict = {f"text_encoder.{k}": v for k, v in new_state_dict.items() if k.startswith("text_model.")}
|
778
|
+
return {**transformer_state_dict, **te_state_dict}
|
779
|
+
|
780
|
+
# This is weird.
|
781
|
+
# https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
|
782
|
+
# has both `peft` and non-peft state dict.
|
783
|
+
has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
|
784
|
+
if has_peft_state_dict:
|
785
|
+
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
|
786
|
+
return state_dict
|
787
|
+
|
788
|
+
# Another weird one.
|
789
|
+
has_mixture = any(
|
790
|
+
k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
|
791
|
+
)
|
792
|
+
|
793
|
+
# ComfyUI.
|
794
|
+
if not has_mixture:
|
795
|
+
state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
|
796
|
+
state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
|
797
|
+
|
798
|
+
has_position_embedding = any("position_embedding" in k for k in state_dict)
|
799
|
+
if has_position_embedding:
|
800
|
+
zero_status_pe = state_dict_all_zero(state_dict, "position_embedding")
|
801
|
+
if zero_status_pe:
|
802
|
+
logger.info(
|
803
|
+
"The `position_embedding` LoRA params are all zeros which make them ineffective. "
|
804
|
+
"So, we will purge them out of the curret state dict to make loading possible."
|
805
|
+
)
|
806
|
+
|
807
|
+
else:
|
808
|
+
logger.info(
|
809
|
+
"The state_dict has position_embedding LoRA params and we currently do not support them. "
|
810
|
+
"Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new."
|
811
|
+
)
|
812
|
+
state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k}
|
813
|
+
|
814
|
+
has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
|
815
|
+
if has_t5xxl:
|
816
|
+
zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl")
|
817
|
+
if zero_status_t5:
|
818
|
+
logger.info(
|
819
|
+
"The `t5xxl` LoRA params are all zeros which make them ineffective. "
|
820
|
+
"So, we will purge them out of the curret state dict to make loading possible."
|
821
|
+
)
|
822
|
+
else:
|
823
|
+
logger.info(
|
824
|
+
"T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out."
|
825
|
+
"Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
|
826
|
+
)
|
827
|
+
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
|
828
|
+
|
829
|
+
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
|
830
|
+
if has_diffb:
|
831
|
+
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
|
832
|
+
if zero_status_diff_b:
|
833
|
+
logger.info(
|
834
|
+
"The `diff_b` LoRA params are all zeros which make them ineffective. "
|
835
|
+
"So, we will purge them out of the curret state dict to make loading possible."
|
836
|
+
)
|
837
|
+
else:
|
838
|
+
logger.info(
|
839
|
+
"`diff_b` keys found in the state dict which are currently unsupported. "
|
840
|
+
"So, we will filter out those keys. Open an issue if this is a problem - "
|
841
|
+
"https://github.com/huggingface/diffusers/issues/new."
|
842
|
+
)
|
843
|
+
state_dict = {k: v for k, v in state_dict.items() if ".diff_b" not in k}
|
844
|
+
|
845
|
+
has_norm_diff = any(".norm" in k and ".diff" in k for k in state_dict)
|
846
|
+
if has_norm_diff:
|
847
|
+
zero_status_diff = state_dict_all_zero(state_dict, ".diff")
|
848
|
+
if zero_status_diff:
|
849
|
+
logger.info(
|
850
|
+
"The `diff` LoRA params are all zeros which make them ineffective. "
|
851
|
+
"So, we will purge them out of the curret state dict to make loading possible."
|
852
|
+
)
|
853
|
+
else:
|
854
|
+
logger.info(
|
855
|
+
"Normalization diff keys found in the state dict which are currently unsupported. "
|
856
|
+
"So, we will filter out those keys. Open an issue if this is a problem - "
|
857
|
+
"https://github.com/huggingface/diffusers/issues/new."
|
858
|
+
)
|
859
|
+
state_dict = {k: v for k, v in state_dict.items() if ".norm" not in k and ".diff" not in k}
|
860
|
+
|
861
|
+
limit_substrings = ["lora_down", "lora_up"]
|
862
|
+
if any("alpha" in k for k in state_dict):
|
863
|
+
limit_substrings.append("alpha")
|
864
|
+
|
865
|
+
state_dict = {
|
866
|
+
_custom_replace(k, limit_substrings): v
|
867
|
+
for k, v in state_dict.items()
|
868
|
+
if k.startswith(("lora_unet_", "lora_te_"))
|
869
|
+
}
|
870
|
+
|
871
|
+
if any("text_projection" in k for k in state_dict):
|
872
|
+
logger.info(
|
873
|
+
"`text_projection` keys found in the `state_dict` which are unexpected. "
|
874
|
+
"So, we will filter out those keys. Open an issue if this is a problem - "
|
875
|
+
"https://github.com/huggingface/diffusers/issues/new."
|
876
|
+
)
|
877
|
+
state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}
|
878
|
+
|
879
|
+
if has_mixture:
|
880
|
+
return _convert_mixture_state_dict_to_diffusers(state_dict)
|
881
|
+
|
561
882
|
return _convert_sd_scripts_to_ai_toolkit(state_dict)
|
562
883
|
|
563
884
|
|
@@ -669,6 +990,26 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
|
|
669
990
|
return new_state_dict
|
670
991
|
|
671
992
|
|
993
|
+
def _custom_replace(key: str, substrings: List[str]) -> str:
|
994
|
+
# Replaces the "."s with "_"s upto the `substrings`.
|
995
|
+
# Example:
|
996
|
+
# lora_unet.foo.bar.lora_A.weight -> lora_unet_foo_bar.lora_A.weight
|
997
|
+
pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"
|
998
|
+
|
999
|
+
match = re.search(pattern, key)
|
1000
|
+
if match:
|
1001
|
+
start_sub = match.start()
|
1002
|
+
if start_sub > 0 and key[start_sub - 1] == ".":
|
1003
|
+
boundary = start_sub - 1
|
1004
|
+
else:
|
1005
|
+
boundary = start_sub
|
1006
|
+
left = key[:boundary].replace(".", "_")
|
1007
|
+
right = key[boundary:]
|
1008
|
+
return left + right
|
1009
|
+
else:
|
1010
|
+
return key.replace(".", "_")
|
1011
|
+
|
1012
|
+
|
672
1013
|
def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
|
673
1014
|
converted_state_dict = {}
|
674
1015
|
original_state_dict_keys = list(original_state_dict.keys())
|
@@ -677,28 +1018,23 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
|
|
677
1018
|
inner_dim = 3072
|
678
1019
|
mlp_ratio = 4.0
|
679
1020
|
|
680
|
-
def swap_scale_shift(weight):
|
681
|
-
shift, scale = weight.chunk(2, dim=0)
|
682
|
-
new_weight = torch.cat([scale, shift], dim=0)
|
683
|
-
return new_weight
|
684
|
-
|
685
1021
|
for lora_key in ["lora_A", "lora_B"]:
|
686
1022
|
## time_text_embed.timestep_embedder <- time_in
|
687
|
-
converted_state_dict[
|
688
|
-
f"
|
689
|
-
|
1023
|
+
converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"] = (
|
1024
|
+
original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
|
1025
|
+
)
|
690
1026
|
if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
|
691
|
-
converted_state_dict[
|
692
|
-
f"
|
693
|
-
|
1027
|
+
converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"] = (
|
1028
|
+
original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
|
1029
|
+
)
|
694
1030
|
|
695
|
-
converted_state_dict[
|
696
|
-
f"
|
697
|
-
|
1031
|
+
converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"] = (
|
1032
|
+
original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
|
1033
|
+
)
|
698
1034
|
if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
|
699
|
-
converted_state_dict[
|
700
|
-
f"
|
701
|
-
|
1035
|
+
converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"] = (
|
1036
|
+
original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
|
1037
|
+
)
|
702
1038
|
|
703
1039
|
## time_text_embed.text_embedder <- vector_in
|
704
1040
|
converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
|
@@ -720,21 +1056,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
|
|
720
1056
|
# guidance
|
721
1057
|
has_guidance = any("guidance" in k for k in original_state_dict)
|
722
1058
|
if has_guidance:
|
723
|
-
converted_state_dict[
|
724
|
-
f"
|
725
|
-
|
1059
|
+
converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"] = (
|
1060
|
+
original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
|
1061
|
+
)
|
726
1062
|
if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
|
727
|
-
converted_state_dict[
|
728
|
-
f"
|
729
|
-
|
1063
|
+
converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"] = (
|
1064
|
+
original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
|
1065
|
+
)
|
730
1066
|
|
731
|
-
converted_state_dict[
|
732
|
-
f"
|
733
|
-
|
1067
|
+
converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"] = (
|
1068
|
+
original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
|
1069
|
+
)
|
734
1070
|
if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
|
735
|
-
converted_state_dict[
|
736
|
-
f"
|
737
|
-
|
1071
|
+
converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"] = (
|
1072
|
+
original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
|
1073
|
+
)
|
738
1074
|
|
739
1075
|
# context_embedder
|
740
1076
|
converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
|
@@ -1148,3 +1484,127 @@ def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
|
|
1148
1484
|
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
1149
1485
|
|
1150
1486
|
return converted_state_dict
|
1487
|
+
|
1488
|
+
|
1489
|
+
def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
|
1490
|
+
# Remove "diffusion_model." prefix from keys.
|
1491
|
+
state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
|
1492
|
+
converted_state_dict = {}
|
1493
|
+
|
1494
|
+
def get_num_layers(keys, pattern):
|
1495
|
+
layers = set()
|
1496
|
+
for key in keys:
|
1497
|
+
match = re.search(pattern, key)
|
1498
|
+
if match:
|
1499
|
+
layers.add(int(match.group(1)))
|
1500
|
+
return len(layers)
|
1501
|
+
|
1502
|
+
def process_block(prefix, index, convert_norm):
|
1503
|
+
# Process attention qkv: pop lora_A and lora_B weights.
|
1504
|
+
lora_down = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_A.weight")
|
1505
|
+
lora_up = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_B.weight")
|
1506
|
+
for attn_key in ["to_q", "to_k", "to_v"]:
|
1507
|
+
converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_A.weight"] = lora_down
|
1508
|
+
for attn_key, weight in zip(["to_q", "to_k", "to_v"], torch.split(lora_up, [2304, 768, 768], dim=0)):
|
1509
|
+
converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_B.weight"] = weight
|
1510
|
+
|
1511
|
+
# Process attention out weights.
|
1512
|
+
converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_A.weight"] = state_dict.pop(
|
1513
|
+
f"{prefix}.{index}.attention.out.lora_A.weight"
|
1514
|
+
)
|
1515
|
+
converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_B.weight"] = state_dict.pop(
|
1516
|
+
f"{prefix}.{index}.attention.out.lora_B.weight"
|
1517
|
+
)
|
1518
|
+
|
1519
|
+
# Process feed-forward weights for layers 1, 2, and 3.
|
1520
|
+
for layer in range(1, 4):
|
1521
|
+
converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_A.weight"] = state_dict.pop(
|
1522
|
+
f"{prefix}.{index}.feed_forward.w{layer}.lora_A.weight"
|
1523
|
+
)
|
1524
|
+
converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_B.weight"] = state_dict.pop(
|
1525
|
+
f"{prefix}.{index}.feed_forward.w{layer}.lora_B.weight"
|
1526
|
+
)
|
1527
|
+
|
1528
|
+
if convert_norm:
|
1529
|
+
converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_A.weight"] = state_dict.pop(
|
1530
|
+
f"{prefix}.{index}.adaLN_modulation.1.lora_A.weight"
|
1531
|
+
)
|
1532
|
+
converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_B.weight"] = state_dict.pop(
|
1533
|
+
f"{prefix}.{index}.adaLN_modulation.1.lora_B.weight"
|
1534
|
+
)
|
1535
|
+
|
1536
|
+
noise_refiner_pattern = r"noise_refiner\.(\d+)\."
|
1537
|
+
num_noise_refiner_layers = get_num_layers(state_dict.keys(), noise_refiner_pattern)
|
1538
|
+
for i in range(num_noise_refiner_layers):
|
1539
|
+
process_block("noise_refiner", i, convert_norm=True)
|
1540
|
+
|
1541
|
+
context_refiner_pattern = r"context_refiner\.(\d+)\."
|
1542
|
+
num_context_refiner_layers = get_num_layers(state_dict.keys(), context_refiner_pattern)
|
1543
|
+
for i in range(num_context_refiner_layers):
|
1544
|
+
process_block("context_refiner", i, convert_norm=False)
|
1545
|
+
|
1546
|
+
core_transformer_pattern = r"layers\.(\d+)\."
|
1547
|
+
num_core_transformer_layers = get_num_layers(state_dict.keys(), core_transformer_pattern)
|
1548
|
+
for i in range(num_core_transformer_layers):
|
1549
|
+
process_block("layers", i, convert_norm=True)
|
1550
|
+
|
1551
|
+
if len(state_dict) > 0:
|
1552
|
+
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
|
1553
|
+
|
1554
|
+
for key in list(converted_state_dict.keys()):
|
1555
|
+
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
1556
|
+
|
1557
|
+
return converted_state_dict
|
1558
|
+
|
1559
|
+
|
1560
|
+
def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
1561
|
+
converted_state_dict = {}
|
1562
|
+
original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
|
1563
|
+
|
1564
|
+
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
|
1565
|
+
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
|
1566
|
+
|
1567
|
+
for i in range(num_blocks):
|
1568
|
+
# Self-attention
|
1569
|
+
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
|
1570
|
+
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
|
1571
|
+
f"blocks.{i}.self_attn.{o}.lora_A.weight"
|
1572
|
+
)
|
1573
|
+
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
|
1574
|
+
f"blocks.{i}.self_attn.{o}.lora_B.weight"
|
1575
|
+
)
|
1576
|
+
|
1577
|
+
# Cross-attention
|
1578
|
+
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
|
1579
|
+
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
|
1580
|
+
f"blocks.{i}.cross_attn.{o}.lora_A.weight"
|
1581
|
+
)
|
1582
|
+
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
|
1583
|
+
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
|
1584
|
+
)
|
1585
|
+
|
1586
|
+
if is_i2v_lora:
|
1587
|
+
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
|
1588
|
+
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
|
1589
|
+
f"blocks.{i}.cross_attn.{o}.lora_A.weight"
|
1590
|
+
)
|
1591
|
+
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
|
1592
|
+
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
|
1593
|
+
)
|
1594
|
+
|
1595
|
+
# FFN
|
1596
|
+
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
|
1597
|
+
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
|
1598
|
+
f"blocks.{i}.{o}.lora_A.weight"
|
1599
|
+
)
|
1600
|
+
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
|
1601
|
+
f"blocks.{i}.{o}.lora_B.weight"
|
1602
|
+
)
|
1603
|
+
|
1604
|
+
if len(original_state_dict) > 0:
|
1605
|
+
raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
|
1606
|
+
|
1607
|
+
for key in list(converted_state_dict.keys()):
|
1608
|
+
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
1609
|
+
|
1610
|
+
return converted_state_dict
|