diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/lora_conversion_utils.py

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -33,6 +33,24 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
     # 1. get all state_dict_keys
     all_keys = list(state_dict.keys())
     sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
+    not_sgm_patterns = ["down_blocks", "mid_block", "up_blocks"]
+
+    # check if state_dict contains both patterns
+    contains_sgm_patterns = False
+    contains_not_sgm_patterns = False
+    for key in all_keys:
+        if any(p in key for p in sgm_patterns):
+            contains_sgm_patterns = True
+        elif any(p in key for p in not_sgm_patterns):
+            contains_not_sgm_patterns = True
+
+    # if state_dict contains both patterns, remove sgm
+    # we can then return state_dict immediately
+    if contains_sgm_patterns and contains_not_sgm_patterns:
+        for key in all_keys:
+            if any(p in key for p in sgm_patterns):
+                state_dict.pop(key)
+        return state_dict

     # 2. check if needs remapping, if not return original dict
     is_in_sgm_format = False
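Note: the hunk above makes `_maybe_map_sgm_blocks_to_diffusers` tolerate checkpoints that mix SGM-style keys (`input_blocks`/`middle_block`/`output_blocks`) with already-converted diffusers-style keys (`down_blocks`/`mid_block`/`up_blocks`): when both appear, the SGM duplicates are dropped. A minimal sketch of that de-duplication, using a hypothetical two-key state dict:

    # Hypothetical toy state dict; keys mirror the two naming schemes handled above.
    state_dict = {"input_blocks.0.weight": 1, "down_blocks.0.weight": 1}
    sgm = ["input_blocks", "middle_block", "output_blocks"]
    non_sgm = ["down_blocks", "mid_block", "up_blocks"]
    has_sgm = any(any(p in k for p in sgm) for k in state_dict)
    has_non_sgm = any(any(p in k for p in non_sgm) for k in state_dict)
    if has_sgm and has_non_sgm:
        # keep only the diffusers-style entries
        state_dict = {k: v for k, v in state_dict.items() if not any(p in k for p in sgm)}
    # state_dict is now {"down_blocks.0.weight": 1}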
@@ -126,7 +144,7 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
             )
             new_state_dict[new_key] = state_dict.pop(key)

-    if
+    if state_dict:
        raise ValueError("At this point all state dict entries have to be converted.")

    return new_state_dict
@@ -415,7 +433,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
         ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
         if not is_sparse:
             # down_weight is copied to each split
-            ait_sd.update(
+            ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))

             # up_weight is split to each split
             ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
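The `dict.fromkeys` rewrite in this hunk is behavior-preserving: every down key maps to the same `down_weight` object, exactly as the comprehension it replaces did. A quick equivalence check with placeholder values (not from the diff):

    keys = ["a.lora_A.weight", "b.lora_A.weight"]
    down_weight = object()  # stands in for the shared tensor
    assert dict.fromkeys(keys, down_weight) == {k: down_weight for k in keys}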
@@ -709,8 +727,25 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
        elif k.startswith("lora_te1_"):
            has_te_keys = True
            continue
+        elif k.startswith("lora_transformer_context_embedder"):
+            diffusers_key = "context_embedder"
+        elif k.startswith("lora_transformer_norm_out_linear"):
+            diffusers_key = "norm_out.linear"
+        elif k.startswith("lora_transformer_proj_out"):
+            diffusers_key = "proj_out"
+        elif k.startswith("lora_transformer_x_embedder"):
+            diffusers_key = "x_embedder"
+        elif k.startswith("lora_transformer_time_text_embed_guidance_embedder_linear_"):
+            i = int(k.split("lora_transformer_time_text_embed_guidance_embedder_linear_")[-1])
+            diffusers_key = f"time_text_embed.guidance_embedder.linear_{i}"
+        elif k.startswith("lora_transformer_time_text_embed_text_embedder_linear_"):
+            i = int(k.split("lora_transformer_time_text_embed_text_embedder_linear_")[-1])
+            diffusers_key = f"time_text_embed.text_embedder.linear_{i}"
+        elif k.startswith("lora_transformer_time_text_embed_timestep_embedder_linear_"):
+            i = int(k.split("lora_transformer_time_text_embed_timestep_embedder_linear_")[-1])
+            diffusers_key = f"time_text_embed.timestep_embedder.linear_{i}"
        else:
-            raise NotImplementedError
+            raise NotImplementedError(f"Handling for key ({k}) is not implemented.")

        if "attn_" in k:
            if "_to_out_0" in k:
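Each new branch above folds a flattened `lora_transformer_*` key back into a dotted diffusers module path, preserving the numeric suffix of the embedder layers. For example, applying the same split to a hypothetical key:

    k = "lora_transformer_time_text_embed_guidance_embedder_linear_1"
    i = int(k.split("lora_transformer_time_text_embed_guidance_embedder_linear_")[-1])
    diffusers_key = f"time_text_embed.guidance_embedder.linear_{i}"
    assert diffusers_key == "time_text_embed.guidance_embedder.linear_1"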
@@ -782,7 +817,11 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
    # has both `peft` and non-peft state dict.
    has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
    if has_peft_state_dict:
-        state_dict = {
+        state_dict = {
+            k.replace("lora_down.weight", "lora_A.weight").replace("lora_up.weight", "lora_B.weight"): v
+            for k, v in state_dict.items()
+            if k.startswith("transformer.")
+        }
        return state_dict

    # Another weird one.
@@ -801,7 +840,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
    if zero_status_pe:
        logger.info(
            "The `position_embedding` LoRA params are all zeros which make them ineffective. "
-            "So, we will purge them out of the
+            "So, we will purge them out of the current state dict to make loading possible."
        )

    else:
@@ -817,7 +856,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
    if zero_status_t5:
        logger.info(
            "The `t5xxl` LoRA params are all zeros which make them ineffective. "
-            "So, we will purge them out of the
+            "So, we will purge them out of the current state dict to make loading possible."
        )
    else:
        logger.info(
@@ -832,7 +871,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
    if zero_status_diff_b:
        logger.info(
            "The `diff_b` LoRA params are all zeros which make them ineffective. "
-            "So, we will purge them out of the
+            "So, we will purge them out of the current state dict to make loading possible."
        )
    else:
        logger.info(
@@ -848,7 +887,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
    if zero_status_diff:
        logger.info(
            "The `diff` LoRA params are all zeros which make them ineffective. "
-            "So, we will purge them out of the
+            "So, we will purge them out of the current state dict to make loading possible."
        )
    else:
        logger.info(
@@ -905,7 +944,7 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]

        # down_weight is copied to each split
-        ait_sd.update(
+        ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))

        # up_weight is split to each split
        ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
@@ -1219,7 +1258,7 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
        )

-    # single
+    # single transformer blocks
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."

@@ -1311,6 +1350,228 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
    return converted_state_dict


+def _convert_fal_kontext_lora_to_diffusers(original_state_dict):
+    converted_state_dict = {}
+    original_state_dict_keys = list(original_state_dict.keys())
+    num_layers = 19
+    num_single_layers = 38
+    inner_dim = 3072
+    mlp_ratio = 4.0
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        original_block_prefix = "base_model.model."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norms
+            converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+            )
+
+            # Q, K, V
+            if lora_key == "lora_A":
+                sample_lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+                context_lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+            else:
+                sample_q, sample_k, sample_v = torch.chunk(
+                    original_state_dict.pop(
+                        f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+                    ),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+                context_q, context_k, context_v = torch.chunk(
+                    original_state_dict.pop(
+                        f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+                    ),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
+            if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+                    original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
+
+            if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+                    original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
+
+            # ff img_mlp
+            converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+                )
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+                )
+            converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+                )
+
+    # single transformer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norm.linear  <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
+                )
+
+            # Q, K, V, mlp
+            mlp_hidden_dim = int(inner_dim * mlp_ratio)
+            split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+
+            if lora_key == "lora_A":
+                lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
+                if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
+            else:
+                q, k, v, mlp = torch.split(
+                    original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
+                    split_size,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+                if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    q_bias, k_bias, v_bias, mlp_bias = torch.split(
+                        original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
+                        split_size,
+                        dim=0,
+                    )
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
+                )
+
+    for lora_key in ["lora_A", "lora_B"]:
+        converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+            f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
+        )
+        if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
+            )
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
+
+
 def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
    converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}

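The new `_convert_fal_kontext_lora_to_diffusers` walks the 19 double and 38 single blocks, pops each `base_model.model.*` entry, and re-emits it under the diffusers `transformer.*` namespace; fused QKV `lora_B` weights are split per projection with `torch.chunk`/`torch.split`, while `lora_A` weights are shared across the projections. A toy illustration of the `lora_B` split (shapes invented for the example; the converter above assumes `inner_dim = 3072`):

    import torch

    rank, inner_dim = 4, 6  # toy sizes for illustration only
    fused_lora_B = torch.randn(3 * inner_dim, rank)  # fused img_attn.qkv lora_B
    q, k, v = torch.chunk(fused_lora_B, 3, dim=0)    # one (inner_dim, rank) slice per projection
    assert q.shape == (inner_dim, rank)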
@@ -1561,45 +1822,286 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
|
1561
1822
|
converted_state_dict = {}
|
1562
1823
|
original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
|
1563
1824
|
|
1564
|
-
|
1825
|
+
block_numbers = {int(k.split(".")[1]) for k in original_state_dict if k.startswith("blocks.")}
|
1826
|
+
min_block = min(block_numbers)
|
1827
|
+
max_block = max(block_numbers)
|
1828
|
+
|
1565
1829
|
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
|
1830
|
+
lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
|
1831
|
+
lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
|
1832
|
+
has_time_projection_weight = any(
|
1833
|
+
k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
|
1834
|
+
)
|
1566
1835
|
|
1567
|
-
|
1836
|
+
def get_alpha_scales(down_weight, alpha_key):
|
1837
|
+
rank = down_weight.shape[0]
|
1838
|
+
alpha = original_state_dict.pop(alpha_key).item()
|
1839
|
+
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
1840
|
+
scale_down = scale
|
1841
|
+
scale_up = 1.0
|
1842
|
+
while scale_down * 2 < scale_up:
|
1843
|
+
scale_down *= 2
|
1844
|
+
scale_up /= 2
|
1845
|
+
return scale_down, scale_up
|
1846
|
+
|
+    for key in list(original_state_dict.keys()):
+        if key.endswith((".diff", ".diff_b")) and "norm" in key:
+            # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
+            # in future if needed and they are not zeroed.
+            original_state_dict.pop(key)
+            logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
+
+        if "time_projection" in key and not has_time_projection_weight:
+            # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
+            # our lora config adds the time proj lora layers, but we don't have the weights for them.
+            # CausVid lora has the weight keys and the bias keys.
+            original_state_dict.pop(key)
+
+    # For the `diff_b` keys, we treat them as lora_bias.
+    # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
+
+    for i in range(min_block, max_block + 1):
         # Self-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.self_attn.{o}.lora_A.weight"
-            )
-            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.self_attn.{o}.lora_B.weight"
-            )
+            alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
+
+            if has_alpha:
+                down_weight = original_state_dict.pop(original_key_A)
+                up_weight = original_state_dict.pop(original_key_B)
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] = down_weight * scale_down
+                converted_state_dict[converted_key_B] = up_weight * scale_up
+
+            else:
+                if original_key_A in original_state_dict:
+                    converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
+                if original_key_B in original_state_dict:
+                    converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
+
+            original_key = f"blocks.{i}.self_attn.{o}.diff_b"
+            converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
 
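Note: for one self-attention projection the loop above performs the following renames, with `diff_b` bias offsets landing on PEFT's `lora_B.bias` slot (block index and tensor labels shown for illustration):

```python
# Before/after for block 0, o="q", c="to_q" (tensors replaced by labels).
source = {
    "blocks.0.self_attn.q.lora_down.weight": "A",  # -> lora_A
    "blocks.0.self_attn.q.lora_up.weight": "B",    # -> lora_B
    "blocks.0.self_attn.q.diff_b": "bias",         # -> lora_B.bias
}
expected = {
    "blocks.0.attn1.to_q.lora_A.weight": "A",
    "blocks.0.attn1.to_q.lora_B.weight": "B",
    "blocks.0.attn1.to_q.lora_B.bias": "bias",
}
```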
         # Cross-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.cross_attn.{o}.lora_A.weight"
-            )
-            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.cross_attn.{o}.lora_B.weight"
-            )
+            alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up
+
+            original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
+            converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
 
         if is_i2v_lora:
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
-                    f"blocks.{i}.cross_attn.{o}.lora_A.weight"
-                )
-                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
-                    f"blocks.{i}.cross_attn.{o}.lora_B.weight"
-                )
+                alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+                has_alpha = alpha_key in original_state_dict
+                original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+                converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+                original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+                converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+                if original_key_A in original_state_dict:
+                    down_weight = original_state_dict.pop(original_key_A)
+                    converted_state_dict[converted_key_A] = down_weight
+                if original_key_B in original_state_dict:
+                    up_weight = original_state_dict.pop(original_key_B)
+                    converted_state_dict[converted_key_B] = up_weight
+                if has_alpha:
+                    scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                    converted_state_dict[converted_key_A] *= scale_down
+                    converted_state_dict[converted_key_B] *= scale_up
+
+                original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
+                converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
+                if original_key in original_state_dict:
+                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
 
         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.lora_A.weight"
+            alpha_key = f"blocks.{i}.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up
+
+            original_key = f"blocks.{i}.{o}.diff_b"
+            converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+    # Remaining.
+    if original_state_dict:
+        if any("time_projection" in k for k in original_state_dict):
+            original_key = f"time_projection.1.{lora_down_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"time_projection.1.{lora_up_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            if "time_projection.1.diff_b" in original_state_dict:
+                converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
+                    "time_projection.1.diff_b"
+                )
+
+        if any("head.head" in k for k in state_dict):
+            converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
+                f"head.head.{lora_down_key}.weight"
             )
-            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.lora_B.weight"
+            converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
+            if "head.head.diff_b" in original_state_dict:
+                converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
+
+        for text_time in ["text_embedding", "time_embedding"]:
+            if any(text_time in k for k in original_state_dict):
+                for b_n in [0, 2]:
+                    diffusers_b_n = 1 if b_n == 0 else 2
+                    diffusers_name = (
+                        "condition_embedder.text_embedder"
+                        if text_time == "text_embedding"
+                        else "condition_embedder.time_embedder"
+                    )
+                    if any(f"{text_time}.{b_n}" in k for k in original_state_dict):
+                        converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_A.weight"] = (
+                            original_state_dict.pop(f"{text_time}.{b_n}.{lora_down_key}.weight")
+                        )
+                        converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.weight"] = (
+                            original_state_dict.pop(f"{text_time}.{b_n}.{lora_up_key}.weight")
+                        )
+                        if f"{text_time}.{b_n}.diff_b" in original_state_dict:
+                            converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.bias"] = (
+                                original_state_dict.pop(f"{text_time}.{b_n}.diff_b")
+                            )
+
+        for img_ours, img_theirs in [
+            ("ff.net.0.proj", "img_emb.proj.1"),
+            ("ff.net.2", "img_emb.proj.3"),
+        ]:
+            original_key = f"{img_theirs}.{lora_down_key}.weight"
+            converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"{img_theirs}.{lora_up_key}.weight"
+            converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            bias_key_theirs = original_key.removesuffix(f".{lora_up_key}.weight") + ".diff_b"
+            if bias_key_theirs in original_state_dict:
+                bias_key = converted_key.removesuffix(".weight") + ".bias"
+                converted_state_dict[bias_key] = original_state_dict.pop(bias_key_theirs)
+
+    if len(original_state_dict) > 0:
+        diff = all(".diff" in k for k in original_state_dict)
+        if diff:
+            diff_keys = {k for k in original_state_dict if k.endswith(".diff")}
+            if not all("lora" not in k for k in diff_keys):
+                raise ValueError
+            logger.info(
+                "The remaining `state_dict` contains `diff` keys which we do not handle yet. If you see performance issues, please file an issue: "
+                "https://github.com/huggingface/diffusers//issues/new"
             )
+        else:
+            raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
+
+
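Note: these converters are dispatched automatically by the LoRA loader; loading a non-diffusers Wan checkpoint looks roughly like the sketch below (repository ids are illustrative, not from the diff):

```python
# Rough usage sketch; assumes a Wan pipeline checkpoint and an illustrative LoRA repo.
import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
# A checkpoint whose keys start with "diffusion_model." is detected and routed
# through _convert_non_diffusers_wan_lora_to_diffusers before being applied.
pipe.load_lora_weights("some-user/wan-causvid-lora", weight_name="lora.safetensors")
```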
+def _convert_musubi_wan_lora_to_diffusers(state_dict):
+    # https://github.com/kohya-ss/musubi-tuner
+    converted_state_dict = {}
+    original_state_dict = {k[len("lora_unet_") :]: v for k, v in state_dict.items()}
+
+    num_blocks = len({k.split("blocks_")[1].split("_")[0] for k in original_state_dict})
+    is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
+
+    def get_alpha_scales(down_weight, key):
+        rank = down_weight.shape[0]
+        alpha = original_state_dict.pop(key + ".alpha").item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
+    for i in range(num_blocks):
+        # Self-attention
+        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_self_attn_{o}")
+            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = up_weight * scale_up
+
+        # Cross-attention
+        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
+            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
+
+        if is_i2v_lora:
+            for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
+                up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
+                scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
+
+        # FFN
+        for o, c in zip(["ffn_0", "ffn_2"], ["net.0.proj", "net.2"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_{o}")
+            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = up_weight * scale_up
 
     if len(original_state_dict) > 0:
         raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
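Note: musubi-tuner flattens module paths with underscores and always ships `alpha` keys, so this converter both re-dots the path and folds the alpha into the weights. For one projection (block 0, self-attention `q`; keys shown for illustration):

```python
# musubi-tuner -> diffusers, block 0, self-attention "q":
#   lora_unet_blocks_0_self_attn_q.lora_down.weight -> blocks.0.attn1.to_q.lora_A.weight  (* scale_down)
#   lora_unet_blocks_0_self_attn_q.lora_up.weight   -> blocks.0.attn1.to_q.lora_B.weight  (* scale_up)
#   lora_unet_blocks_0_self_attn_q.alpha            -> consumed by get_alpha_scales
```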
@@ -1608,3 +2110,123 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
 
     return converted_state_dict
+
+
+def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
+    if not all(k.startswith(non_diffusers_prefix) for k in state_dict):
+        raise ValueError("Invalid LoRA state dict for HiDream.")
+    converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
+    converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+    return converted_state_dict
+
+
+def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
+    if not all(k.startswith(f"{non_diffusers_prefix}.") for k in state_dict):
+        raise ValueError("Invalid LoRA state dict for LTX-Video.")
+    converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
+    converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+    return converted_state_dict
+
+
|
+
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
|
2132
|
+
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
|
2133
|
+
if has_lora_unet:
|
2134
|
+
state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
|
2135
|
+
|
2136
|
+
def convert_key(key: str) -> str:
|
2137
|
+
prefix = "transformer_blocks"
|
2138
|
+
if "." in key:
|
2139
|
+
base, suffix = key.rsplit(".", 1)
|
2140
|
+
else:
|
2141
|
+
base, suffix = key, ""
|
2142
|
+
|
2143
|
+
start = f"{prefix}_"
|
2144
|
+
rest = base[len(start) :]
|
2145
|
+
|
2146
|
+
if "." in rest:
|
2147
|
+
head, tail = rest.split(".", 1)
|
2148
|
+
tail = "." + tail
|
2149
|
+
else:
|
2150
|
+
head, tail = rest, ""
|
2151
|
+
|
2152
|
+
# Protected n-grams that must keep their internal underscores
|
2153
|
+
protected = {
|
2154
|
+
# pairs
|
2155
|
+
("to", "q"),
|
2156
|
+
("to", "k"),
|
2157
|
+
("to", "v"),
|
2158
|
+
("to", "out"),
|
2159
|
+
("add", "q"),
|
2160
|
+
("add", "k"),
|
2161
|
+
("add", "v"),
|
2162
|
+
("txt", "mlp"),
|
2163
|
+
("img", "mlp"),
|
2164
|
+
("txt", "mod"),
|
2165
|
+
("img", "mod"),
|
2166
|
+
# triplets
|
2167
|
+
("add", "q", "proj"),
|
2168
|
+
("add", "k", "proj"),
|
2169
|
+
("add", "v", "proj"),
|
2170
|
+
("to", "add", "out"),
|
2171
|
+
}
|
2172
|
+
|
2173
|
+
prot_by_len = {}
|
2174
|
+
for ng in protected:
|
2175
|
+
prot_by_len.setdefault(len(ng), set()).add(ng)
|
2176
|
+
|
2177
|
+
parts = head.split("_")
|
2178
|
+
merged = []
|
2179
|
+
i = 0
|
2180
|
+
lengths_desc = sorted(prot_by_len.keys(), reverse=True)
|
2181
|
+
|
2182
|
+
while i < len(parts):
|
2183
|
+
matched = False
|
2184
|
+
for L in lengths_desc:
|
2185
|
+
if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
|
2186
|
+
merged.append("_".join(parts[i : i + L]))
|
2187
|
+
i += L
|
2188
|
+
matched = True
|
2189
|
+
break
|
2190
|
+
if not matched:
|
2191
|
+
merged.append(parts[i])
|
2192
|
+
i += 1
|
2193
|
+
|
2194
|
+
head_converted = ".".join(merged)
|
2195
|
+
converted_base = f"{prefix}.{head_converted}{tail}"
|
2196
|
+
return converted_base + (("." + suffix) if suffix else "")
|
2197
|
+
|
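Note: `convert_key` re-dots the flattened module path while keeping multi-word leaf names intact via the protected n-grams. A worked example traced by hand through the merging loop above:

```python
# key  = "transformer_blocks_0_attn_to_q.lora_down.weight"
#   base = "transformer_blocks_0_attn_to_q.lora_down", suffix = "weight"
#   head = "0_attn_to_q", tail = ".lora_down"
#   parts ["0", "attn", "to", "q"] -> merged ["0", "attn", "to_q"]  (("to", "q") is protected)
# result: "transformer_blocks.0.attn.to_q.lora_down.weight"
```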
+    state_dict = {convert_key(k): v for k, v in state_dict.items()}
+
+    converted_state_dict = {}
+    all_keys = list(state_dict.keys())
+    down_key = ".lora_down.weight"
+    up_key = ".lora_up.weight"
+
+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
+    for k in all_keys:
+        if k.endswith(down_key):
+            diffusers_down_key = k.replace(down_key, ".lora_A.weight")
+            diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
+            alpha_key = k.replace(down_key, ".alpha")
+
+            down_weight = state_dict.pop(k)
+            up_weight = state_dict.pop(k.replace(down_key, up_key))
+            scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+            converted_state_dict[diffusers_down_key] = down_weight * scale_down
+            converted_state_dict[diffusers_up_key] = up_weight * scale_up
+
+    if len(state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+    converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+    return converted_state_dict
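Note: putting the pieces together for a single Qwen key (illustrative; assumes a `lora_unet_`-prefixed musubi-style checkpoint):

```python
# "lora_unet_transformer_blocks_0_attn_to_q.lora_down.weight"
#   strip prefix -> "transformer_blocks_0_attn_to_q.lora_down.weight"
#   convert_key  -> "transformer_blocks.0.attn.to_q.lora_down.weight"
#   pairing loop -> "transformer_blocks.0.attn.to_q.lora_A.weight"  (scaled by scale_down)
#   final prefix -> "transformer.transformer_blocks.0.attn.to_q.lora_A.weight"
```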