diffusers-0.27.0-py3-none-any.whl → diffusers-0.32.2-py3-none-any.whl
- diffusers/__init__.py +233 -6
- diffusers/callbacks.py +209 -0
- diffusers/commands/env.py +102 -6
- diffusers/configuration_utils.py +45 -16
- diffusers/dependency_versions_table.py +4 -3
- diffusers/image_processor.py +434 -110
- diffusers/loaders/__init__.py +42 -9
- diffusers/loaders/ip_adapter.py +626 -36
- diffusers/loaders/lora_base.py +900 -0
- diffusers/loaders/lora_conversion_utils.py +991 -125
- diffusers/loaders/lora_pipeline.py +3812 -0
- diffusers/loaders/peft.py +571 -7
- diffusers/loaders/single_file.py +405 -173
- diffusers/loaders/single_file_model.py +385 -0
- diffusers/loaders/single_file_utils.py +1783 -713
- diffusers/loaders/textual_inversion.py +41 -23
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +464 -540
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +76 -7
- diffusers/models/activations.py +65 -10
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +605 -18
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +4304 -687
- diffusers/models/autoencoders/__init__.py +8 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +110 -28
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
- diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
- diffusers/models/autoencoders/vae.py +41 -29
- diffusers/models/autoencoders/vq_model.py +182 -0
- diffusers/models/controlnet.py +47 -800
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +68 -0
- diffusers/models/controlnet_sparsectrl.py +116 -0
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/controlnets/controlnet_xs.py +1946 -0
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/downsampling.py +85 -18
- diffusers/models/embeddings.py +1856 -158
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +480 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +2 -7
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +611 -146
- diffusers/models/normalization.py +361 -20
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformers/__init__.py +16 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/dual_transformer_2d.py +9 -8
- diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +445 -0
- diffusers/models/transformers/prior_transformer.py +13 -13
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +297 -187
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +593 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +461 -0
- diffusers/models/transformers/transformer_temporal.py +21 -19
- diffusers/models/unets/unet_1d.py +8 -8
- diffusers/models/unets/unet_1d_blocks.py +31 -31
- diffusers/models/unets/unet_2d.py +17 -10
- diffusers/models/unets/unet_2d_blocks.py +225 -149
- diffusers/models/unets/unet_2d_condition.py +50 -53
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +192 -1057
- diffusers/models/unets/unet_3d_condition.py +22 -27
- diffusers/models/unets/unet_i2vgen_xl.py +22 -18
- diffusers/models/unets/unet_kandinsky3.py +2 -2
- diffusers/models/unets/unet_motion_model.py +1413 -89
- diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
- diffusers/models/unets/unet_stable_cascade.py +19 -18
- diffusers/models/unets/uvit_2d.py +2 -2
- diffusers/models/upsampling.py +95 -26
- diffusers/models/vq_model.py +12 -164
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +202 -3
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +8 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
- diffusers/pipelines/auto_pipeline.py +196 -28
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/cogvideo/__init__.py +54 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
- diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
- diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
- diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +7 -4
- diffusers/pipelines/flux/__init__.py +69 -0
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +957 -0
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +37 -0
- diffusers/pipelines/free_init_utils.py +41 -38
- diffusers/pipelines/free_noise_utils.py +596 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +338 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/pag/__init__.py +80 -0
- diffusers/pipelines/pag/pag_utils.py +243 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +74 -164
- diffusers/pipelines/pipeline_flax_utils.py +5 -10
- diffusers/pipelines/pipeline_loading_utils.py +515 -53
- diffusers/pipelines/pipeline_utils.py +411 -222
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
- diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
- diffusers/schedulers/__init__.py +12 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +27 -26
- diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
- diffusers/schedulers/scheduling_ddpm.py +27 -30
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +150 -50
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
- diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
- diffusers/schedulers/scheduling_edm_euler.py +62 -39
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
- diffusers/schedulers/scheduling_euler_discrete.py +255 -74
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
- diffusers/schedulers/scheduling_heun_discrete.py +174 -46
- diffusers/schedulers/scheduling_ipndm.py +9 -9
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +23 -29
- diffusers/schedulers/scheduling_lms_discrete.py +105 -28
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +21 -21
- diffusers/schedulers/scheduling_sasolver.py +157 -60
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +41 -36
- diffusers/schedulers/scheduling_unclip.py +19 -16
- diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
- diffusers/schedulers/scheduling_utils.py +12 -5
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +214 -30
- diffusers/utils/__init__.py +17 -1
- diffusers/utils/constants.py +3 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +592 -7
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
- diffusers/utils/dynamic_modules_utils.py +34 -29
- diffusers/utils/export_utils.py +50 -6
- diffusers/utils/hub_utils.py +131 -17
- diffusers/utils/import_utils.py +210 -8
- diffusers/utils/loading_utils.py +118 -5
- diffusers/utils/logging.py +4 -2
- diffusers/utils/peft_utils.py +37 -7
- diffusers/utils/state_dict_utils.py +13 -2
- diffusers/utils/testing_utils.py +193 -11
- diffusers/utils/torch_utils.py +4 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
- diffusers-0.32.2.dist-info/RECORD +550 -0
- {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1349
- diffusers/models/prior_transformer.py +0 -12
- diffusers/models/t5_film_transformer.py +0 -70
- diffusers/models/transformer_2d.py +0 -25
- diffusers/models/transformer_temporal.py +0 -34
- diffusers/models/unet_1d.py +0 -26
- diffusers/models/unet_1d_blocks.py +0 -203
- diffusers/models/unet_2d.py +0 -27
- diffusers/models/unet_2d_blocks.py +0 -375
- diffusers/models/unet_2d_condition.py +0 -25
- diffusers-0.27.0.dist-info/RECORD +0 -399
- {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
- {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
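
In the lora_conversion_utils.py diff that follows, the Flux converters fold kohya's runtime LoRA scale, alpha / rank, into the stored weights, balancing the factor between the down and up matrices so that neither one is multiplied by an extreme value (see _convert_to_ai_toolkit below). A minimal sketch of that rescaling step, with illustrative numbers only, not the library's API:

# Sketch of the alpha/rank rescaling performed by _convert_to_ai_toolkit in the
# diff below (illustrative values; kohya checkpoints store a scalar alpha per
# lora_down/lora_up pair).
alpha, rank = 8.0, 32
scale = alpha / rank  # kohya applies this at forward time; the converter bakes it into the weights

# Balance the factor across the two matrices; the product always equals `scale`.
scale_down, scale_up = scale, 1.0
while scale_down * 2 < scale_up:
    scale_down *= 2
    scale_up /= 2

# lora_A (down) is multiplied by scale_down and lora_B (up) by scale_up, so
# (scale_up * B) @ (scale_down * A) == scale * (B @ A).
assert abs(scale_down * scale_up - scale) < 1e-12

Splitting the product this way keeps both factors close to each other, which presumably helps numerical behavior when the rescaled weights are stored in low precision.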
@@ -14,7 +14,9 @@
|
|
14
14
|
|
15
15
|
import re
|
16
16
|
|
17
|
-
|
17
|
+
import torch
|
18
|
+
|
19
|
+
from ..utils import is_peft_version, logging
|
18
20
|
|
19
21
|
|
20
22
|
logger = logging.get_logger(__name__)
|
@@ -123,153 +125,100 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
|
|
123
125
|
return new_state_dict
|
124
126
|
|
125
127
|
|
126
|
-
def
|
128
|
+
def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
|
129
|
+
"""
|
130
|
+
Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.
|
131
|
+
|
132
|
+
Args:
|
133
|
+
state_dict (`dict`): The state dict to convert.
|
134
|
+
unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
|
135
|
+
text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
|
136
|
+
"text_encoder".
|
137
|
+
|
138
|
+
Returns:
|
139
|
+
`tuple`: A tuple containing the converted state dict and a dictionary of alphas.
|
140
|
+
"""
|
127
141
|
unet_state_dict = {}
|
128
142
|
te_state_dict = {}
|
129
143
|
te2_state_dict = {}
|
130
144
|
network_alphas = {}
|
131
145
|
|
132
|
-
#
|
133
|
-
|
134
|
-
for
|
146
|
+
# Check for DoRA-enabled LoRAs.
|
147
|
+
dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
|
148
|
+
dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
|
149
|
+
dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
|
150
|
+
if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
|
151
|
+
if is_peft_version("<", "0.9.0"):
|
152
|
+
raise ValueError(
|
153
|
+
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
154
|
+
)
|
155
|
+
|
156
|
+
# Iterate over all LoRA weights.
|
157
|
+
all_lora_keys = list(state_dict.keys())
|
158
|
+
for key in all_lora_keys:
|
159
|
+
if not key.endswith("lora_down.weight"):
|
160
|
+
continue
|
161
|
+
|
162
|
+
# Extract LoRA name.
|
135
163
|
lora_name = key.split(".")[0]
|
164
|
+
|
165
|
+
# Find corresponding up weight and alpha.
|
136
166
|
lora_name_up = lora_name + ".lora_up.weight"
|
137
167
|
lora_name_alpha = lora_name + ".alpha"
|
138
168
|
|
169
|
+
# Handle U-Net LoRAs.
|
139
170
|
if lora_name.startswith("lora_unet_"):
|
140
|
-
diffusers_name = key
|
171
|
+
diffusers_name = _convert_unet_lora_key(key)
|
141
172
|
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
|
173
|
+
# Store down and up weights.
|
174
|
+
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
175
|
+
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
146
176
|
|
147
|
-
if
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
else:
|
154
|
-
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
|
155
|
-
|
156
|
-
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
|
157
|
-
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
|
158
|
-
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
|
159
|
-
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
|
160
|
-
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
|
161
|
-
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
|
162
|
-
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
|
163
|
-
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
|
164
|
-
|
165
|
-
# SDXL specificity.
|
166
|
-
if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
|
167
|
-
pattern = r"\.\d+(?=\D*$)"
|
168
|
-
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
|
169
|
-
if ".in." in diffusers_name:
|
170
|
-
diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
|
171
|
-
if ".out." in diffusers_name:
|
172
|
-
diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
|
173
|
-
if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
|
174
|
-
diffusers_name = diffusers_name.replace("op", "conv")
|
175
|
-
if "skip" in diffusers_name:
|
176
|
-
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
|
177
|
-
|
178
|
-
# LyCORIS specificity.
|
179
|
-
if "time.emb.proj" in diffusers_name:
|
180
|
-
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
|
181
|
-
if "conv.shortcut" in diffusers_name:
|
182
|
-
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
|
183
|
-
|
184
|
-
# General coverage.
|
185
|
-
if "transformer_blocks" in diffusers_name:
|
186
|
-
if "attn1" in diffusers_name or "attn2" in diffusers_name:
|
187
|
-
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
|
188
|
-
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
|
189
|
-
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
190
|
-
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
191
|
-
elif "ff" in diffusers_name:
|
192
|
-
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
193
|
-
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
194
|
-
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
|
195
|
-
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
196
|
-
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
197
|
-
else:
|
198
|
-
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
199
|
-
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
200
|
-
|
201
|
-
elif lora_name.startswith("lora_te_"):
|
202
|
-
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
|
203
|
-
diffusers_name = diffusers_name.replace("text.model", "text_model")
|
204
|
-
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
|
205
|
-
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
|
206
|
-
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
|
207
|
-
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
|
208
|
-
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
|
209
|
-
if "self_attn" in diffusers_name:
|
210
|
-
te_state_dict[diffusers_name] = state_dict.pop(key)
|
211
|
-
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
212
|
-
elif "mlp" in diffusers_name:
|
213
|
-
# Be aware that this is the new diffusers convention and the rest of the code might
|
214
|
-
# not utilize it yet.
|
215
|
-
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
|
216
|
-
te_state_dict[diffusers_name] = state_dict.pop(key)
|
217
|
-
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
177
|
+
# Store DoRA scale if present.
|
178
|
+
if dora_present_in_unet:
|
179
|
+
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
|
180
|
+
unet_state_dict[
|
181
|
+
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
|
182
|
+
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
218
183
|
|
219
|
-
#
|
220
|
-
elif lora_name.startswith("lora_te1_"):
|
221
|
-
diffusers_name = key
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
|
226
|
-
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
|
227
|
-
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
|
228
|
-
if "self_attn" in diffusers_name:
|
229
|
-
te_state_dict[diffusers_name] = state_dict.pop(key)
|
230
|
-
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
231
|
-
elif "mlp" in diffusers_name:
|
232
|
-
# Be aware that this is the new diffusers convention and the rest of the code might
|
233
|
-
# not utilize it yet.
|
234
|
-
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
|
184
|
+
# Handle text encoder LoRAs.
|
185
|
+
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
|
186
|
+
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
|
187
|
+
|
188
|
+
# Store down and up weights for te or te2.
|
189
|
+
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
235
190
|
te_state_dict[diffusers_name] = state_dict.pop(key)
|
236
191
|
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
237
|
-
|
238
|
-
# (sayakpaul): Duplicate code. Needs to be cleaned.
|
239
|
-
elif lora_name.startswith("lora_te2_"):
|
240
|
-
diffusers_name = key.replace("lora_te2_", "").replace("_", ".")
|
241
|
-
diffusers_name = diffusers_name.replace("text.model", "text_model")
|
242
|
-
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
|
243
|
-
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
|
244
|
-
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
|
245
|
-
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
|
246
|
-
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
|
247
|
-
if "self_attn" in diffusers_name:
|
248
|
-
te2_state_dict[diffusers_name] = state_dict.pop(key)
|
249
|
-
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
250
|
-
elif "mlp" in diffusers_name:
|
251
|
-
# Be aware that this is the new diffusers convention and the rest of the code might
|
252
|
-
# not utilize it yet.
|
253
|
-
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
|
192
|
+
else:
|
254
193
|
te2_state_dict[diffusers_name] = state_dict.pop(key)
|
255
194
|
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
256
195
|
|
257
|
-
|
196
|
+
# Store DoRA scale if present.
|
197
|
+
if dora_present_in_te or dora_present_in_te2:
|
198
|
+
dora_scale_key_to_replace_te = (
|
199
|
+
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
|
200
|
+
)
|
201
|
+
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
202
|
+
te_state_dict[
|
203
|
+
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
204
|
+
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
205
|
+
elif lora_name.startswith("lora_te2_"):
|
206
|
+
te2_state_dict[
|
207
|
+
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
208
|
+
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
209
|
+
|
210
|
+
# Store alpha if present.
|
258
211
|
if lora_name_alpha in state_dict:
|
259
212
|
alpha = state_dict.pop(lora_name_alpha).item()
|
260
|
-
|
261
|
-
prefix = "unet."
|
262
|
-
elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
|
263
|
-
prefix = "text_encoder."
|
264
|
-
else:
|
265
|
-
prefix = "text_encoder_2."
|
266
|
-
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
|
267
|
-
network_alphas.update({new_name: alpha})
|
213
|
+
network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))
|
268
214
|
|
215
|
+
# Check if any keys remain.
|
269
216
|
if len(state_dict) > 0:
|
270
|
-
raise ValueError(f"The following keys have not been correctly
|
217
|
+
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
|
218
|
+
|
219
|
+
logger.info("Non-diffusers checkpoint detected.")
|
271
220
|
|
272
|
-
|
221
|
+
# Construct final state dict.
|
273
222
|
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
|
274
223
|
te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
|
275
224
|
te2_state_dict = (
|
@@ -282,3 +231,920 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
|
|
282
231
|
|
283
232
|
new_state_dict = {**unet_state_dict, **te_state_dict}
|
284
233
|
return new_state_dict, network_alphas
|
234
|
+
|
235
|
+
|
236
|
+
def _convert_unet_lora_key(key):
|
237
|
+
"""
|
238
|
+
Converts a U-Net LoRA key to a Diffusers compatible key.
|
239
|
+
"""
|
240
|
+
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
|
241
|
+
|
242
|
+
# Replace common U-Net naming patterns.
|
243
|
+
diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
|
244
|
+
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
|
245
|
+
diffusers_name = diffusers_name.replace("middle.block", "mid_block")
|
246
|
+
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
|
247
|
+
diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
|
248
|
+
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
|
249
|
+
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
|
250
|
+
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
|
251
|
+
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
|
252
|
+
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
|
253
|
+
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
|
254
|
+
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
|
255
|
+
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
|
256
|
+
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
|
257
|
+
|
258
|
+
# SDXL specific conversions.
|
259
|
+
if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
|
260
|
+
pattern = r"\.\d+(?=\D*$)"
|
261
|
+
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
|
262
|
+
if ".in." in diffusers_name:
|
263
|
+
diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
|
264
|
+
if ".out." in diffusers_name:
|
265
|
+
diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
|
266
|
+
if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
|
267
|
+
diffusers_name = diffusers_name.replace("op", "conv")
|
268
|
+
if "skip" in diffusers_name:
|
269
|
+
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
|
270
|
+
|
271
|
+
# LyCORIS specific conversions.
|
272
|
+
if "time.emb.proj" in diffusers_name:
|
273
|
+
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
|
274
|
+
if "conv.shortcut" in diffusers_name:
|
275
|
+
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
|
276
|
+
|
277
|
+
# General conversions.
|
278
|
+
if "transformer_blocks" in diffusers_name:
|
279
|
+
if "attn1" in diffusers_name or "attn2" in diffusers_name:
|
280
|
+
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
|
281
|
+
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
|
282
|
+
elif "ff" in diffusers_name:
|
283
|
+
pass
|
284
|
+
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
|
285
|
+
pass
|
286
|
+
else:
|
287
|
+
pass
|
288
|
+
|
289
|
+
return diffusers_name
|
290
|
+
|
291
|
+
|
292
|
+
def _convert_text_encoder_lora_key(key, lora_name):
|
293
|
+
"""
|
294
|
+
Converts a text encoder LoRA key to a Diffusers compatible key.
|
295
|
+
"""
|
296
|
+
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
297
|
+
key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
|
298
|
+
else:
|
299
|
+
key_to_replace = "lora_te2_"
|
300
|
+
|
301
|
+
diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
|
302
|
+
diffusers_name = diffusers_name.replace("text.model", "text_model")
|
303
|
+
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
|
304
|
+
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
|
305
|
+
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
|
306
|
+
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
|
307
|
+
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
|
308
|
+
diffusers_name = diffusers_name.replace("text.projection", "text_projection")
|
309
|
+
|
310
|
+
if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
|
311
|
+
pass
|
312
|
+
elif "mlp" in diffusers_name:
|
313
|
+
# Be aware that this is the new diffusers convention and the rest of the code might
|
314
|
+
# not utilize it yet.
|
315
|
+
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
|
316
|
+
return diffusers_name
|
317
|
+
|
318
|
+
|
319
|
+
def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
|
320
|
+
"""
|
321
|
+
Gets the correct alpha name for the Diffusers model.
|
322
|
+
"""
|
323
|
+
if lora_name_alpha.startswith("lora_unet_"):
|
324
|
+
prefix = "unet."
|
325
|
+
elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
|
326
|
+
prefix = "text_encoder."
|
327
|
+
else:
|
328
|
+
prefix = "text_encoder_2."
|
329
|
+
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
|
330
|
+
return {new_name: alpha}
|
331
|
+
|
332
|
+
|
333
|
+
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
|
334
|
+
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
|
335
|
+
# All credits go to `kohya-ss`.
|
336
|
+
def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
337
|
+
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
|
338
|
+
if sds_key + ".lora_down.weight" not in sds_sd:
|
339
|
+
return
|
340
|
+
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
|
341
|
+
|
342
|
+
# scale weight by alpha and dim
|
343
|
+
rank = down_weight.shape[0]
|
344
|
+
alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
|
345
|
+
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
346
|
+
|
347
|
+
# calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
|
348
|
+
scale_down = scale
|
349
|
+
scale_up = 1.0
|
350
|
+
while scale_down * 2 < scale_up:
|
351
|
+
scale_down *= 2
|
352
|
+
scale_up /= 2
|
353
|
+
|
354
|
+
ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
|
355
|
+
ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
|
356
|
+
|
357
|
+
def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
|
358
|
+
if sds_key + ".lora_down.weight" not in sds_sd:
|
359
|
+
return
|
360
|
+
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
|
361
|
+
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
|
362
|
+
sd_lora_rank = down_weight.shape[0]
|
363
|
+
|
364
|
+
# scale weight by alpha and dim
|
365
|
+
alpha = sds_sd.pop(sds_key + ".alpha")
|
366
|
+
scale = alpha / sd_lora_rank
|
367
|
+
|
368
|
+
# calculate scale_down and scale_up
|
369
|
+
scale_down = scale
|
370
|
+
scale_up = 1.0
|
371
|
+
while scale_down * 2 < scale_up:
|
372
|
+
scale_down *= 2
|
373
|
+
scale_up /= 2
|
374
|
+
|
375
|
+
down_weight = down_weight * scale_down
|
376
|
+
up_weight = up_weight * scale_up
|
377
|
+
|
378
|
+
# calculate dims if not provided
|
379
|
+
num_splits = len(ait_keys)
|
380
|
+
if dims is None:
|
381
|
+
dims = [up_weight.shape[0] // num_splits] * num_splits
|
382
|
+
else:
|
383
|
+
assert sum(dims) == up_weight.shape[0]
|
384
|
+
|
385
|
+
# check upweight is sparse or not
|
386
|
+
is_sparse = False
|
387
|
+
if sd_lora_rank % num_splits == 0:
|
388
|
+
ait_rank = sd_lora_rank // num_splits
|
389
|
+
is_sparse = True
|
390
|
+
i = 0
|
391
|
+
for j in range(len(dims)):
|
392
|
+
for k in range(len(dims)):
|
393
|
+
if j == k:
|
394
|
+
continue
|
395
|
+
is_sparse = is_sparse and torch.all(
|
396
|
+
up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
|
397
|
+
)
|
398
|
+
i += dims[j]
|
399
|
+
if is_sparse:
|
400
|
+
logger.info(f"weight is sparse: {sds_key}")
|
401
|
+
|
402
|
+
# make ai-toolkit weight
|
403
|
+
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
|
404
|
+
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
|
405
|
+
if not is_sparse:
|
406
|
+
# down_weight is copied to each split
|
407
|
+
ait_sd.update({k: down_weight for k in ait_down_keys})
|
408
|
+
|
409
|
+
# up_weight is split to each split
|
410
|
+
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
|
411
|
+
else:
|
412
|
+
# down_weight is chunked to each split
|
413
|
+
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416
|
414
|
+
|
415
|
+
# up_weight is sparse: only non-zero values are copied to each split
|
416
|
+
i = 0
|
417
|
+
for j in range(len(dims)):
|
418
|
+
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
|
419
|
+
i += dims[j]
|
420
|
+
|
421
|
+
def _convert_sd_scripts_to_ai_toolkit(sds_sd):
|
422
|
+
ait_sd = {}
|
423
|
+
for i in range(19):
|
424
|
+
_convert_to_ai_toolkit(
|
425
|
+
sds_sd,
|
426
|
+
ait_sd,
|
427
|
+
f"lora_unet_double_blocks_{i}_img_attn_proj",
|
428
|
+
f"transformer.transformer_blocks.{i}.attn.to_out.0",
|
429
|
+
)
|
430
|
+
_convert_to_ai_toolkit_cat(
|
431
|
+
sds_sd,
|
432
|
+
ait_sd,
|
433
|
+
f"lora_unet_double_blocks_{i}_img_attn_qkv",
|
434
|
+
[
|
435
|
+
f"transformer.transformer_blocks.{i}.attn.to_q",
|
436
|
+
f"transformer.transformer_blocks.{i}.attn.to_k",
|
437
|
+
f"transformer.transformer_blocks.{i}.attn.to_v",
|
438
|
+
],
|
439
|
+
)
|
440
|
+
_convert_to_ai_toolkit(
|
441
|
+
sds_sd,
|
442
|
+
ait_sd,
|
443
|
+
f"lora_unet_double_blocks_{i}_img_mlp_0",
|
444
|
+
f"transformer.transformer_blocks.{i}.ff.net.0.proj",
|
445
|
+
)
|
446
|
+
_convert_to_ai_toolkit(
|
447
|
+
sds_sd,
|
448
|
+
ait_sd,
|
449
|
+
f"lora_unet_double_blocks_{i}_img_mlp_2",
|
450
|
+
f"transformer.transformer_blocks.{i}.ff.net.2",
|
451
|
+
)
|
452
|
+
_convert_to_ai_toolkit(
|
453
|
+
sds_sd,
|
454
|
+
ait_sd,
|
455
|
+
f"lora_unet_double_blocks_{i}_img_mod_lin",
|
456
|
+
f"transformer.transformer_blocks.{i}.norm1.linear",
|
457
|
+
)
|
458
|
+
_convert_to_ai_toolkit(
|
459
|
+
sds_sd,
|
460
|
+
ait_sd,
|
461
|
+
f"lora_unet_double_blocks_{i}_txt_attn_proj",
|
462
|
+
f"transformer.transformer_blocks.{i}.attn.to_add_out",
|
463
|
+
)
|
464
|
+
_convert_to_ai_toolkit_cat(
|
465
|
+
sds_sd,
|
466
|
+
ait_sd,
|
467
|
+
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
|
468
|
+
[
|
469
|
+
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
|
470
|
+
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
|
471
|
+
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
|
472
|
+
],
|
473
|
+
)
|
474
|
+
_convert_to_ai_toolkit(
|
475
|
+
sds_sd,
|
476
|
+
ait_sd,
|
477
|
+
f"lora_unet_double_blocks_{i}_txt_mlp_0",
|
478
|
+
f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
|
479
|
+
)
|
480
|
+
_convert_to_ai_toolkit(
|
481
|
+
sds_sd,
|
482
|
+
ait_sd,
|
483
|
+
f"lora_unet_double_blocks_{i}_txt_mlp_2",
|
484
|
+
f"transformer.transformer_blocks.{i}.ff_context.net.2",
|
485
|
+
)
|
486
|
+
_convert_to_ai_toolkit(
|
487
|
+
sds_sd,
|
488
|
+
ait_sd,
|
489
|
+
f"lora_unet_double_blocks_{i}_txt_mod_lin",
|
490
|
+
f"transformer.transformer_blocks.{i}.norm1_context.linear",
|
491
|
+
)
|
492
|
+
|
493
|
+
for i in range(38):
|
494
|
+
_convert_to_ai_toolkit_cat(
|
495
|
+
sds_sd,
|
496
|
+
ait_sd,
|
497
|
+
f"lora_unet_single_blocks_{i}_linear1",
|
498
|
+
[
|
499
|
+
f"transformer.single_transformer_blocks.{i}.attn.to_q",
|
500
|
+
f"transformer.single_transformer_blocks.{i}.attn.to_k",
|
501
|
+
f"transformer.single_transformer_blocks.{i}.attn.to_v",
|
502
|
+
f"transformer.single_transformer_blocks.{i}.proj_mlp",
|
503
|
+
],
|
504
|
+
dims=[3072, 3072, 3072, 12288],
|
505
|
+
)
|
506
|
+
_convert_to_ai_toolkit(
|
507
|
+
sds_sd,
|
508
|
+
ait_sd,
|
509
|
+
f"lora_unet_single_blocks_{i}_linear2",
|
510
|
+
f"transformer.single_transformer_blocks.{i}.proj_out",
|
511
|
+
)
|
512
|
+
_convert_to_ai_toolkit(
|
513
|
+
sds_sd,
|
514
|
+
ait_sd,
|
515
|
+
f"lora_unet_single_blocks_{i}_modulation_lin",
|
516
|
+
f"transformer.single_transformer_blocks.{i}.norm.linear",
|
517
|
+
)
|
518
|
+
|
519
|
+
remaining_keys = list(sds_sd.keys())
|
520
|
+
te_state_dict = {}
|
521
|
+
if remaining_keys:
|
522
|
+
if not all(k.startswith("lora_te1") for k in remaining_keys):
|
523
|
+
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
|
524
|
+
for key in remaining_keys:
|
525
|
+
if not key.endswith("lora_down.weight"):
|
526
|
+
continue
|
527
|
+
|
528
|
+
lora_name = key.split(".")[0]
|
529
|
+
lora_name_up = f"{lora_name}.lora_up.weight"
|
530
|
+
lora_name_alpha = f"{lora_name}.alpha"
|
531
|
+
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
|
532
|
+
|
533
|
+
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
534
|
+
down_weight = sds_sd.pop(key)
|
535
|
+
sd_lora_rank = down_weight.shape[0]
|
536
|
+
te_state_dict[diffusers_name] = down_weight
|
537
|
+
te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)
|
538
|
+
|
539
|
+
if lora_name_alpha in sds_sd:
|
540
|
+
alpha = sds_sd.pop(lora_name_alpha).item()
|
541
|
+
scale = alpha / sd_lora_rank
|
542
|
+
|
543
|
+
scale_down = scale
|
544
|
+
scale_up = 1.0
|
545
|
+
while scale_down * 2 < scale_up:
|
546
|
+
scale_down *= 2
|
547
|
+
scale_up /= 2
|
548
|
+
|
549
|
+
te_state_dict[diffusers_name] *= scale_down
|
550
|
+
te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up
|
551
|
+
|
552
|
+
if len(sds_sd) > 0:
|
553
|
+
logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
|
554
|
+
|
555
|
+
if te_state_dict:
|
556
|
+
te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}
|
557
|
+
|
558
|
+
new_state_dict = {**ait_sd, **te_state_dict}
|
559
|
+
return new_state_dict
|
560
|
+
|
561
|
+
return _convert_sd_scripts_to_ai_toolkit(state_dict)
|
562
|
+
|
563
|
+
|
564
|
+
+# Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6
+# Some utilities were reused from
+# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
+def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
+    new_state_dict = {}
+    orig_keys = list(old_state_dict.keys())
+
+    def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
+        down_weight = sds_sd.pop(sds_key)
+        up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))
+
+        # calculate dims if not provided
+        num_splits = len(ait_keys)
+        if dims is None:
+            dims = [up_weight.shape[0] // num_splits] * num_splits
+        else:
+            assert sum(dims) == up_weight.shape[0]
+
+        # make ai-toolkit weight
+        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
+        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
+
+        # down_weight is copied to each split
+        ait_sd.update({k: down_weight for k in ait_down_keys})
+
+        # up_weight is split to each split
+        ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
+
+    for old_key in orig_keys:
+        # Handle double_blocks
+        if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")):
+            block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1)
+            new_key = f"transformer.transformer_blocks.{block_num}"
+
+            if "processor.proj_lora1" in old_key:
+                new_key += ".attn.to_out.0"
+            elif "processor.proj_lora2" in old_key:
+                new_key += ".attn.to_add_out"
+            # Handle text latents.
+            elif "processor.qkv_lora2" in old_key and "up" not in old_key:
+                handle_qkv(
+                    old_state_dict,
+                    new_state_dict,
+                    old_key,
+                    [
+                        f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
+                        f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
+                        f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
+                    ],
+                )
+                # continue
+            # Handle image latents.
+            elif "processor.qkv_lora1" in old_key and "up" not in old_key:
+                handle_qkv(
+                    old_state_dict,
+                    new_state_dict,
+                    old_key,
+                    [
+                        f"transformer.transformer_blocks.{block_num}.attn.to_q",
+                        f"transformer.transformer_blocks.{block_num}.attn.to_k",
+                        f"transformer.transformer_blocks.{block_num}.attn.to_v",
+                    ],
+                )
+                # continue
+
+            if "down" in old_key:
+                new_key += ".lora_A.weight"
+            elif "up" in old_key:
+                new_key += ".lora_B.weight"
+
+        # Handle single_blocks
+        elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")):
+            block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
+            new_key = f"transformer.single_transformer_blocks.{block_num}"
+
+            if "proj_lora" in old_key:
+                new_key += ".proj_out"
+            elif "qkv_lora" in old_key and "up" not in old_key:
+                handle_qkv(
+                    old_state_dict,
+                    new_state_dict,
+                    old_key,
+                    [
+                        f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
+                        f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
+                        f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
+                    ],
+                )
+
+            if "down" in old_key:
+                new_key += ".lora_A.weight"
+            elif "up" in old_key:
+                new_key += ".lora_B.weight"
+
+        else:
+            # Handle other potential key patterns here
+            new_key = old_key
+
+        # Since we already handle qkv above.
+        if "qkv" not in old_key:
+            new_state_dict[new_key] = old_state_dict.pop(old_key)
+
+    if len(old_state_dict) > 0:
+        raise ValueError(f"`old_state_dict` should be empty at this point but has: {list(old_state_dict.keys())}.")
+
+    return new_state_dict
+
+
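
# Illustrative sketch (not part of the diff): the handle_qkv helper above converts one
# fused qkv LoRA into three per-projection LoRAs by sharing the down/lora_A matrix and
# splitting the rows of the up/lora_B matrix. The algebra below shows the two forms
# produce identical outputs; all dimensions are arbitrary toy values.
import torch

dim, rank, d_in = 6, 3, 10
down = torch.randn(rank, d_in)   # shared lora_A, copied to q, k, and v
up = torch.randn(3 * dim, rank)  # fused lora_B covering q, k, v

x = torch.randn(d_in)
fused = up @ (down @ x)  # output of the fused qkv LoRA
q, k, v = torch.split(up, [dim, dim, dim], dim=0)
split = torch.cat([q @ (down @ x), k @ (down @ x), v @ (down @ x)], dim=0)
assert torch.allclose(fused, split, atol=1e-6)
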
+def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
+    converted_state_dict = {}
+    original_state_dict_keys = list(original_state_dict.keys())
+    num_layers = 19
+    num_single_layers = 38
+    inner_dim = 3072
+    mlp_ratio = 4.0
+
+    def swap_scale_shift(weight):
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        return new_weight
+
+    for lora_key in ["lora_A", "lora_B"]:
+        ## time_text_embed.timestep_embedder <- time_in
+        converted_state_dict[
+            f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
+        ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
+        if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[
+                f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
+            ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
+
+        converted_state_dict[
+            f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
+        ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
+        if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[
+                f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
+            ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
+
+        ## time_text_embed.text_embedder <- vector_in
+        converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
+            f"vector_in.in_layer.{lora_key}.weight"
+        )
+        if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.bias"] = original_state_dict.pop(
+                f"vector_in.in_layer.{lora_key}.bias"
+            )
+
+        converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = original_state_dict.pop(
+            f"vector_in.out_layer.{lora_key}.weight"
+        )
+        if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.bias"] = original_state_dict.pop(
+                f"vector_in.out_layer.{lora_key}.bias"
+            )
+
+        # guidance
+        has_guidance = any("guidance" in k for k in original_state_dict)
+        if has_guidance:
+            converted_state_dict[
+                f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
+            ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
+            if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[
+                    f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
+                ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
+
+            converted_state_dict[
+                f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
+            ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
+            if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[
+                    f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
+                ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
+
+        # context_embedder
+        converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
+            f"txt_in.{lora_key}.weight"
+        )
+        if f"txt_in.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop(
+                f"txt_in.{lora_key}.bias"
+            )
+
+        # x_embedder
+        converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop(f"img_in.{lora_key}.weight")
+        if f"img_in.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop(f"img_in.{lora_key}.bias")
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norms
+            converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_mod.lin.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias"
+                )
+
+            # Q, K, V
+            if lora_key == "lora_A":
+                sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight")
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+                context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight")
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+            else:
+                sample_q, sample_k, sample_v = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+                context_q, context_k, context_v = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
+            if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
+
+            if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
+
+            # ff img_mlp
+            converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+                )
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+                )
+            converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+                )
+
+        # qk_norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+        )
+
+    # single transformer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norm.linear <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"single_blocks.{i}.modulation.lin.{lora_key}.weight"
+            )
+            if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"single_blocks.{i}.modulation.lin.{lora_key}.bias"
+                )
+
+            # Q, K, V, mlp
+            mlp_hidden_dim = int(inner_dim * mlp_ratio)
+            split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+
+            if lora_key == "lora_A":
+                lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight")
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
+                if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
+            else:
+                q, k, v, mlp = torch.split(
+                    original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+                if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    q_bias, k_bias, v_bias, mlp_bias = torch.split(
+                        original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0
+                    )
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"single_blocks.{i}.linear2.{lora_key}.weight"
+            )
+            if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"single_blocks.{i}.linear2.{lora_key}.bias"
+                )
+
+        # qk norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+            f"single_blocks.{i}.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+            f"single_blocks.{i}.norm.key_norm.scale"
+        )
+
+    for lora_key in ["lora_A", "lora_B"]:
+        converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+            f"final_layer.linear.{lora_key}.weight"
+        )
+        if f"final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                f"final_layer.linear.{lora_key}.bias"
+            )
+
+        converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift(
+            original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight")
+        )
+        if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift(
+                original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias")
+            )
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
+
+
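
# Illustrative sketch (not part of the diff): swap_scale_shift above reorders the fused
# adaLN parameters because the original checkpoint stores them as [shift, scale] while
# the diffusers norm_out layer consumes them as [scale, shift]. A toy tensor makes the
# reordering visible; the values are arbitrary.
import torch

weight = torch.tensor([1.0, 2.0, 3.0, 4.0])  # pretend layout: [shift..., scale...]
shift, scale = weight.chunk(2, dim=0)
reordered = torch.cat([scale, shift], dim=0)
assert torch.equal(reordered, torch.tensor([3.0, 4.0, 1.0, 2.0]))
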
+def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
+    converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
+
+    def remap_norm_scale_shift_(key, state_dict):
+        weight = state_dict.pop(key)
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
+    def remap_txt_in_(key, state_dict):
+        def rename_key(key):
+            new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+            new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+            new_key = new_key.replace("txt_in", "context_embedder")
+            new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+            new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+            new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+            new_key = new_key.replace("mlp", "ff")
+            return new_key
+
+        if "self_attn_qkv" in key:
+            weight = state_dict.pop(key)
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+        else:
+            state_dict[rename_key(key)] = state_dict.pop(key)
+
+    def remap_img_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        if "lora_A" in key:
+            state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
+            state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
+            state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
+        else:
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+            state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+            state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+    def remap_txt_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        if "lora_A" in key:
+            state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
+            state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
+            state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
+        else:
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+            state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+            state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
+    def remap_single_transformer_blocks_(key, state_dict):
+        hidden_size = 3072
+
+        if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
+            linear1_weight = state_dict.pop(key)
+            if "lora_A" in key:
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_A.weight"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
+            else:
+                split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+                q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_B.weight"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
+                state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
+                state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
+                state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp
+
+        elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
+            linear1_bias = state_dict.pop(key)
+            if "lora_A" in key:
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_A.bias"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
+            else:
+                split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+                q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_B.bias"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
+                state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
+                state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
+                state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias
+
+        else:
+            new_key = key.replace("single_blocks", "single_transformer_blocks")
+            new_key = new_key.replace("linear2", "proj_out")
+            new_key = new_key.replace("q_norm", "attn.norm_q")
+            new_key = new_key.replace("k_norm", "attn.norm_k")
+            state_dict[new_key] = state_dict.pop(key)
+
+    TRANSFORMER_KEYS_RENAME_DICT = {
+        "img_in": "x_embedder",
+        "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+        "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+        "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+        "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+        "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+        "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+        "double_blocks": "transformer_blocks",
+        "img_attn_q_norm": "attn.norm_q",
+        "img_attn_k_norm": "attn.norm_k",
+        "img_attn_proj": "attn.to_out.0",
+        "txt_attn_q_norm": "attn.norm_added_q",
+        "txt_attn_k_norm": "attn.norm_added_k",
+        "txt_attn_proj": "attn.to_add_out",
+        "img_mod.linear": "norm1.linear",
+        "img_norm1": "norm1.norm",
+        "img_norm2": "norm2",
+        "img_mlp": "ff",
+        "txt_mod.linear": "norm1_context.linear",
+        "txt_norm1": "norm1.norm",
+        "txt_norm2": "norm2_context",
+        "txt_mlp": "ff_context",
+        "self_attn_proj": "attn.to_out.0",
+        "modulation.linear": "norm.linear",
+        "pre_norm": "norm.norm",
+        "final_layer.norm_final": "norm_out.norm",
+        "final_layer.linear": "proj_out",
+        "fc1": "net.0.proj",
+        "fc2": "net.2",
+        "input_embedder": "proj_in",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP = {
+        "txt_in": remap_txt_in_,
+        "img_attn_qkv": remap_img_attn_qkv_,
+        "txt_attn_qkv": remap_txt_attn_qkv_,
+        "single_blocks": remap_single_transformer_blocks_,
+        "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+    }
+
+    # Some folks attempt to make their state dict compatible with diffusers by adding a "transformer." prefix to all
+    # keys and using their own custom code. To make sure both "original" and "attempted diffusers" LoRAs work as
+    # expected, we make sure that both follow the same initial format by stripping off the "transformer." prefix.
+    for key in list(converted_state_dict.keys()):
+        if key.startswith("transformer."):
+            converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
+        if key.startswith("diffusion_model."):
+            converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
+
+    # Rename and remap the state dict keys
+    for key in list(converted_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    # Add back the "transformer." prefix
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
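
# Illustrative sketch (not part of the diff): _convert_hunyuan_video_lora_to_diffusers
# above uses a two-pass scheme -- plain substring renames driven by
# TRANSFORMER_KEYS_RENAME_DICT, then in-place handlers for keys (such as fused qkv
# projections) that must be split rather than merely renamed. The toy dict and rules
# below are hypothetical stand-ins for the real tables.
def split_qkv_(key, sd):
    # replace one fused entry with three per-projection entries, as the handlers above do
    fused = sd.pop(key)
    for p in ("q", "k", "v"):
        sd[key.replace("qkv", p)] = fused / 3

RENAME = {"blk": "blocks", "w": "weight"}
SPECIAL = {"qkv": split_qkv_}

state = {"blk.0.w": 1.0, "blk.0.qkv": 9.0}
for key in list(state.keys()):  # pass 1: substring renames
    new_key = key
    for old, new in RENAME.items():
        new_key = new_key.replace(old, new)
    state[new_key] = state.pop(key)
for key in list(state.keys()):  # pass 2: special handlers mutate the dict in place
    for marker, handler in SPECIAL.items():
        if marker in key:
            handler(key, state)
print(sorted(state))  # ['blocks.0.k', 'blocks.0.q', 'blocks.0.v', 'blocks.0.weight']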