diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to their public registries, and is provided for informational purposes only.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +595 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
- diffusers-0.33.1.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
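
(The hunks rendered below are from diffusers/loaders/single_file_utils.py, +646 -72. Bare `-` lines mark removals whose original text is not shown.)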
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright
+# Copyright 2025 The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -44,6 +44,7 @@ from ..utils import (
     is_transformers_available,
     logging,
 )
+from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
 from ..utils.hub_utils import _get_model_file


@@ -94,6 +95,12 @@ CHECKPOINT_KEY_NAMES = {
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
     "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
     "animatediff_rgb": "controlnet_cond_embedding.weight",
+    "auraflow": [
+        "double_layers.0.attn.w2q.weight",
+        "double_layers.0.attn.w1q.weight",
+        "cond_seq_linear.weight",
+        "t_embedder.mlp.0.weight",
+    ],
     "flux": [
         "double_blocks.0.img_attn.norm.key_norm.scale",
         "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -109,6 +116,16 @@ CHECKPOINT_KEY_NAMES = {
     "autoencoder-dc-sana": "encoder.project_in.conv.bias",
     "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
     "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
+    "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
+    "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
+    "sana": [
+        "blocks.0.cross_attn.q_linear.weight",
+        "blocks.0.cross_attn.q_linear.bias",
+        "blocks.0.cross_attn.kv_linear.weight",
+        "blocks.0.cross_attn.kv_linear.bias",
+    ],
+    "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
+    "wan_vae": "decoder.middle.0.residual.0.gamma",
 }

 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -153,6 +170,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
     "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
     "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
+    "auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"},
     "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
     "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
     "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
@@ -165,6 +183,12 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
     "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
     "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
+    "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
+    "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
+    "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
+    "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
+    "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
+    "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
 }

 # Use to configure model sample size when original config is provided
@@ -177,6 +201,7 @@ DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = {
     "inpainting": 512,
     "inpainting_v2": 512,
     "controlnet": 512,
+    "instruct-pix2pix": 512,
     "v2": 768,
     "v1": 512,
 }
@@ -378,12 +403,14 @@ def load_single_file_checkpoint(
     cache_dir=None,
     local_files_only=None,
     revision=None,
+    disable_mmap=False,
 ):
     if os.path.isfile(pretrained_model_link_or_path):
         pretrained_model_link_or_path = pretrained_model_link_or_path

     else:
         repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
+        user_agent = {"file_type": "single_file", "framework": "pytorch"}
         pretrained_model_link_or_path = _get_model_file(
             repo_id,
             weights_name=weights_name,
@@ -393,9 +420,10 @@ def load_single_file_checkpoint(
             local_files_only=local_files_only,
             token=token,
             revision=revision,
+            user_agent=user_agent,
         )

-    checkpoint = load_state_dict(pretrained_model_link_or_path)
+    checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)

     # some checkpoints contain the model state dict under a "state_dict" key
     while "state_dict" in checkpoint:
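
The new disable_mmap argument is threaded from the user-facing loaders down to load_state_dict, so a safetensors file can be read fully into memory instead of memory-mapped (which can be slow or problematic on network filesystems). A minimal usage sketch, assuming from_single_file forwards the flag as the diffusers/loaders/single_file.py changes in the file list suggest; the checkpoint path is hypothetical:

    from diffusers import StableDiffusionPipeline

    # Hypothetical local checkpoint; disable_mmap=True makes the loader read
    # the whole file into memory rather than memory-mapping it.
    pipe = StableDiffusionPipeline.from_single_file(
        "/models/v1-5-pruned-emaonly.safetensors",
        disable_mmap=True,
    )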
@@ -416,7 +444,7 @@ def fetch_original_config(original_config_file, local_files_only=False):
                 "Please provide a valid local file path."
             )

-        original_config_file = BytesIO(requests.get(original_config_file).content)
+        original_config_file = BytesIO(requests.get(original_config_file, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)

     else:
         raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
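
DIFFUSERS_REQUEST_TIMEOUT comes from the enlarged diffusers/utils/constants.py (+13 -1 in the file list) and bounds how long a remote config fetch may block; without a timeout, requests.get can hang indefinitely on an unresponsive host. A self-contained sketch of the same pattern, with an assumed timeout value:

    from io import BytesIO

    import requests

    REQUEST_TIMEOUT = 60  # assumed value, in seconds

    def fetch_remote_config(url: str) -> BytesIO:
        # Raises requests.exceptions.Timeout instead of blocking forever.
        response = requests.get(url, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        return BytesIO(response.content)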
@@ -637,6 +665,36 @@ def infer_diffusers_model_type(checkpoint):
     elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
         model_type = "hunyuan-video"

+    elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
+        model_type = "auraflow"
+
+    elif (
+        CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
+        and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
+    ):
+        model_type = "instruct-pix2pix"
+
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
+        model_type = "lumina2"
+
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
+        model_type = "sana"
+
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
+        if "model.diffusion_model.patch_embedding.weight" in checkpoint:
+            target_key = "model.diffusion_model.patch_embedding.weight"
+        else:
+            target_key = "patch_embedding.weight"
+
+        if checkpoint[target_key].shape[0] == 1536:
+            model_type = "wan-t2v-1.3B"
+        elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
+            model_type = "wan-t2v-14B"
+        else:
+            model_type = "wan-i2v-14B"
+    elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
+        # All Wan models use the same VAE so we can use the same default model repo to fetch the config
+        model_type = "wan-t2v-14B"
     else:
         model_type = "v1"

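
Note how the Wan branch disambiguates variants purely by tensor shape: 1536 output channels on the patch embedding means the 1.3B T2V model, 5120 channels with 16 input channels means the 14B T2V model, and anything else falls through to the 14B I2V model. A toy illustration of that dispatch with a dummy tensor (the shape below is illustrative, not taken from a real checkpoint):

    import torch

    def infer_wan_variant(checkpoint: dict) -> str:
        # Mirrors the branch added to infer_diffusers_model_type() above.
        key = (
            "model.diffusion_model.patch_embedding.weight"
            if "model.diffusion_model.patch_embedding.weight" in checkpoint
            else "patch_embedding.weight"
        )
        weight = checkpoint[key]
        if weight.shape[0] == 1536:
            return "wan-t2v-1.3B"
        if weight.shape[0] == 5120 and weight.shape[1] == 16:
            return "wan-t2v-14B"
        return "wan-i2v-14B"

    # Dummy stand-in for a 1.3B text-to-video patch embedding.
    fake = {"patch_embedding.weight": torch.zeros(1536, 16, 1, 2, 2)}
    print(infer_wan_variant(fake))  # wan-t2v-1.3B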
@@ -1423,8 +1481,8 @@ def convert_open_clip_checkpoint(

     if text_proj_key in checkpoint:
         text_proj_dim = int(checkpoint[text_proj_key].shape[0])
-    elif hasattr(text_model.config, "
-        text_proj_dim = text_model.config.
+    elif hasattr(text_model.config, "hidden_size"):
+        text_proj_dim = text_model.config.hidden_size
     else:
         text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM

@@ -1568,18 +1626,9 @@ def create_diffusers_clip_model_from_ldm(
         raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")

     if is_accelerate_available():
-
+        load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
     else:
-
-
-        if model._keys_to_ignore_on_load_unexpected is not None:
-            for pat in model._keys_to_ignore_on_load_unexpected:
-                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-            )
+        model.load_state_dict(diffusers_format_checkpoint, strict=False)

     if torch_dtype is not None:
         model.to(torch_dtype)
@@ -2036,16 +2085,7 @@ def create_diffusers_t5_model_from_checkpoint(
     diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)

     if is_accelerate_available():
-
-        if model._keys_to_ignore_on_load_unexpected is not None:
-            for pat in model._keys_to_ignore_on_load_unexpected:
-                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-            )
-
+        load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
     else:
         model.load_state_dict(diffusers_format_checkpoint)

@@ -2086,6 +2126,7 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
 def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
+
     for k in keys:
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
@@ -2366,7 +2407,6 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
         "per_channel_statistics.channel": remove_keys_,
         "per_channel_statistics.mean-of-means": remove_keys_,
         "per_channel_statistics.mean-of-stds": remove_keys_,
-        "timestep_scale_multiplier": remove_keys_,
     }

     if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
@@ -2460,7 +2500,7 @@ def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):


 def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
-
+    converted_state_dict = {}

     # Comfy checkpoints add this prefix
     keys = list(checkpoint.keys())
@@ -2469,22 +2509,22 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

     # Convert patch_embed
-
-
+    converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")

     # Convert time_embed
-
-
-
-
-
-
-
-
-
-
-
-
+    converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+    converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+    converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+    converted_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
+    converted_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
+    converted_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
+    converted_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
+    converted_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
+    converted_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
+    converted_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
+    converted_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")

     # Convert transformer blocks
     num_layers = 48
@@ -2493,68 +2533,84 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         old_prefix = f"blocks.{i}."

         # norm1
-
-
+        converted_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
+        converted_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
         if i < num_layers - 1:
-
-
+            converted_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(
+                old_prefix + "mod_y.weight"
+            )
+            converted_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(
+                old_prefix + "mod_y.bias"
+            )
         else:
-
+            converted_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
                 old_prefix + "mod_y.weight"
             )
-
+            converted_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(
+                old_prefix + "mod_y.bias"
+            )

         # Visual attention
         qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
         q, k, v = qkv_weight.chunk(3, dim=0)

-
-
-
-
-
-
-
+        converted_state_dict[block_prefix + "attn1.to_q.weight"] = q
+        converted_state_dict[block_prefix + "attn1.to_k.weight"] = k
+        converted_state_dict[block_prefix + "attn1.to_v.weight"] = v
+        converted_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(
+            old_prefix + "attn.q_norm_x.weight"
+        )
+        converted_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(
+            old_prefix + "attn.k_norm_x.weight"
+        )
+        converted_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(
+            old_prefix + "attn.proj_x.weight"
+        )
+        converted_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")

         # Context attention
         qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
         q, k, v = qkv_weight.chunk(3, dim=0)

-
-
-
+        converted_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
+        converted_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
+        converted_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
+        converted_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
             old_prefix + "attn.q_norm_y.weight"
         )
-
+        converted_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
             old_prefix + "attn.k_norm_y.weight"
         )
         if i < num_layers - 1:
-
+            converted_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
                 old_prefix + "attn.proj_y.weight"
             )
-
+            converted_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(
+                old_prefix + "attn.proj_y.bias"
+            )

         # MLP
-
+        converted_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
             checkpoint.pop(old_prefix + "mlp_x.w1.weight")
         )
-
+        converted_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
         if i < num_layers - 1:
-
+            converted_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
                 checkpoint.pop(old_prefix + "mlp_y.w1.weight")
             )
-
+            converted_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(
+                old_prefix + "mlp_y.w2.weight"
+            )

     # Output layers
-
-
-
-
+    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
+    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")

-
+    converted_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")

-    return
+    return converted_state_dict


 def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
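
A recurring move in these converters (Mochi here, Sana below) is splitting a fused qkv projection matrix into separate to_q/to_k/to_v weights with chunk. A standalone sketch of why the split is lossless, using a made-up hidden size:

    import torch

    hidden = 8  # made-up hidden size for illustration
    fused_qkv = torch.randn(3 * hidden, hidden)  # rows stacked as [q; k; v]

    q, k, v = fused_qkv.chunk(3, dim=0)  # equal thirds along dim 0

    x = torch.randn(hidden)
    # Applying the fused projection and slicing its output matches
    # applying the three chunks independently.
    fused_out = fused_qkv @ x
    assert torch.allclose(fused_out[:hidden], q @ x)
    assert torch.allclose(fused_out[hidden : 2 * hidden], k @ x)
    assert torch.allclose(fused_out[2 * hidden :], v @ x)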
@@ -2685,3 +2741,521 @@ def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
|
|
2685
2741
|
handler_fn_inplace(key, checkpoint)
|
2686
2742
|
|
2687
2743
|
return checkpoint
|
2744
|
+
|
2745
|
+
|
2746
|
+
def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
2747
|
+
converted_state_dict = {}
|
2748
|
+
state_dict_keys = list(checkpoint.keys())
|
2749
|
+
|
2750
|
+
# Handle register tokens and positional embeddings
|
2751
|
+
converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)
|
2752
|
+
|
2753
|
+
# Handle time step projection
|
2754
|
+
converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
|
2755
|
+
converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
|
2756
|
+
converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
|
2757
|
+
converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)
|
2758
|
+
|
2759
|
+
# Handle context embedder
|
2760
|
+
converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)
|
2761
|
+
|
2762
|
+
# Calculate the number of layers
|
2763
|
+
def calculate_layers(keys, key_prefix):
|
2764
|
+
layers = set()
|
2765
|
+
for k in keys:
|
2766
|
+
if key_prefix in k:
|
2767
|
+
layer_num = int(k.split(".")[1]) # get the layer number
|
2768
|
+
layers.add(layer_num)
|
2769
|
+
return len(layers)
|
2770
|
+
|
2771
|
+
mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
|
2772
|
+
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
|
2773
|
+
|
2774
|
+
# MMDiT blocks
|
2775
|
+
for i in range(mmdit_layers):
|
2776
|
+
# Feed-forward
|
2777
|
+
path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
|
2778
|
+
weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
|
2779
|
+
for orig_k, diffuser_k in path_mapping.items():
|
2780
|
+
for k, v in weight_mapping.items():
|
2781
|
+
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
|
2782
|
+
f"double_layers.{i}.{orig_k}.{k}.weight", None
|
2783
|
+
)
|
2784
|
+
|
2785
|
+
# Norms
|
2786
|
+
path_mapping = {"modX": "norm1", "modC": "norm1_context"}
|
2787
|
+
for orig_k, diffuser_k in path_mapping.items():
|
2788
|
+
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
|
2789
|
+
f"double_layers.{i}.{orig_k}.1.weight", None
|
2790
|
+
)
|
2791
|
+
|
2792
|
+
# Attentions
|
2793
|
+
x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
|
2794
|
+
context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
|
2795
|
+
for attn_mapping in [x_attn_mapping, context_attn_mapping]:
|
2796
|
+
for k, v in attn_mapping.items():
|
2797
|
+
converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
|
2798
|
+
f"double_layers.{i}.attn.{k}.weight", None
|
2799
|
+
)
|
2800
|
+
|
2801
|
+
# Single-DiT blocks
|
2802
|
+
for i in range(single_dit_layers):
|
2803
|
+
# Feed-forward
|
2804
|
+
mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
|
2805
|
+
for k, v in mapping.items():
|
2806
|
+
converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
|
2807
|
+
f"single_layers.{i}.mlp.{k}.weight", None
|
2808
|
+
)
|
2809
|
+
|
2810
|
+
# Norms
|
2811
|
+
converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
|
2812
|
+
f"single_layers.{i}.modCX.1.weight", None
|
2813
|
+
)
|
2814
|
+
|
2815
|
+
# Attentions
|
2816
|
+
x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
|
2817
|
+
for k, v in x_attn_mapping.items():
|
2818
|
+
converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
|
2819
|
+
f"single_layers.{i}.attn.{k}.weight", None
|
2820
|
+
)
|
2821
|
+
# Final blocks
|
2822
|
+
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)
|
2823
|
+
|
2824
|
+
# Handle the final norm layer
|
2825
|
+
norm_weight = checkpoint.pop("modF.1.weight", None)
|
2826
|
+
if norm_weight is not None:
|
2827
|
+
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
|
2828
|
+
else:
|
2829
|
+
converted_state_dict["norm_out.linear.weight"] = None
|
2830
|
+
|
2831
|
+
converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
|
2832
|
+
converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
|
2833
|
+
converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")
|
2834
|
+
|
2835
|
+
return converted_state_dict
|
2836
|
+
|
2837
|
+
|
2838
|
+
def convert_lumina2_to_diffusers(checkpoint, **kwargs):
|
2839
|
+
converted_state_dict = {}
|
2840
|
+
|
2841
|
+
# Original Lumina-Image-2 has an extra norm paramter that is unused
|
2842
|
+
# We just remove it here
|
2843
|
+
checkpoint.pop("norm_final.weight", None)
|
2844
|
+
|
2845
|
+
# Comfy checkpoints add this prefix
|
2846
|
+
keys = list(checkpoint.keys())
|
2847
|
+
for k in keys:
|
2848
|
+
if "model.diffusion_model." in k:
|
2849
|
+
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
2850
|
+
|
2851
|
+
LUMINA_KEY_MAP = {
|
2852
|
+
"cap_embedder": "time_caption_embed.caption_embedder",
|
2853
|
+
"t_embedder.mlp.0": "time_caption_embed.timestep_embedder.linear_1",
|
2854
|
+
"t_embedder.mlp.2": "time_caption_embed.timestep_embedder.linear_2",
|
2855
|
+
"attention": "attn",
|
2856
|
+
".out.": ".to_out.0.",
|
2857
|
+
"k_norm": "norm_k",
|
2858
|
+
"q_norm": "norm_q",
|
2859
|
+
"w1": "linear_1",
|
2860
|
+
"w2": "linear_2",
|
2861
|
+
"w3": "linear_3",
|
2862
|
+
"adaLN_modulation.1": "norm1.linear",
|
2863
|
+
}
|
2864
|
+
ATTENTION_NORM_MAP = {
|
2865
|
+
"attention_norm1": "norm1.norm",
|
2866
|
+
"attention_norm2": "norm2",
|
2867
|
+
}
|
2868
|
+
CONTEXT_REFINER_MAP = {
|
2869
|
+
"context_refiner.0.attention_norm1": "context_refiner.0.norm1",
|
2870
|
+
"context_refiner.0.attention_norm2": "context_refiner.0.norm2",
|
2871
|
+
"context_refiner.1.attention_norm1": "context_refiner.1.norm1",
|
2872
|
+
"context_refiner.1.attention_norm2": "context_refiner.1.norm2",
|
2873
|
+
}
|
2874
|
+
FINAL_LAYER_MAP = {
|
2875
|
+
"final_layer.adaLN_modulation.1": "norm_out.linear_1",
|
2876
|
+
"final_layer.linear": "norm_out.linear_2",
|
2877
|
+
}
|
2878
|
+
|
2879
|
+
def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
|
2880
|
+
q_dim = 2304
|
2881
|
+
k_dim = v_dim = 768
|
2882
|
+
|
2883
|
+
to_q, to_k, to_v = torch.split(tensor, [q_dim, k_dim, v_dim], dim=0)
|
2884
|
+
|
2885
|
+
return {
|
2886
|
+
diffusers_key.replace("qkv", "to_q"): to_q,
|
2887
|
+
diffusers_key.replace("qkv", "to_k"): to_k,
|
2888
|
+
diffusers_key.replace("qkv", "to_v"): to_v,
|
2889
|
+
}
|
2890
|
+
|
2891
|
+
for key in keys:
|
2892
|
+
diffusers_key = key
|
2893
|
+
for k, v in CONTEXT_REFINER_MAP.items():
|
2894
|
+
diffusers_key = diffusers_key.replace(k, v)
|
2895
|
+
for k, v in FINAL_LAYER_MAP.items():
|
2896
|
+
diffusers_key = diffusers_key.replace(k, v)
|
2897
|
+
for k, v in ATTENTION_NORM_MAP.items():
|
2898
|
+
diffusers_key = diffusers_key.replace(k, v)
|
2899
|
+
for k, v in LUMINA_KEY_MAP.items():
|
2900
|
+
diffusers_key = diffusers_key.replace(k, v)
|
2901
|
+
|
2902
|
+
if "qkv" in diffusers_key:
|
2903
|
+
converted_state_dict.update(convert_lumina_attn_to_diffusers(checkpoint.pop(key), diffusers_key))
|
2904
|
+
else:
|
2905
|
+
converted_state_dict[diffusers_key] = checkpoint.pop(key)
|
2906
|
+
|
2907
|
+
return converted_state_dict
|
2908
|
+
|
2909
|
+
|
2910
|
+
def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
|
2911
|
+
converted_state_dict = {}
|
2912
|
+
keys = list(checkpoint.keys())
|
2913
|
+
for k in keys:
|
2914
|
+
if "model.diffusion_model." in k:
|
2915
|
+
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
2916
|
+
|
2917
|
+
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401
|
2918
|
+
|
2919
|
+
# Positional and patch embeddings.
|
2920
|
+
checkpoint.pop("pos_embed")
|
2921
|
+
converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
|
2922
|
+
converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
|
2923
|
+
|
2924
|
+
# Timestep embeddings.
|
2925
|
+
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
|
2926
|
+
"t_embedder.mlp.0.weight"
|
2927
|
+
)
|
2928
|
+
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
|
2929
|
+
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
|
2930
|
+
"t_embedder.mlp.2.weight"
|
2931
|
+
)
|
2932
|
+
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
|
2933
|
+
converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
|
2934
|
+
converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
|
2935
|
+
|
2936
|
+
# Caption Projection.
|
2937
|
+
checkpoint.pop("y_embedder.y_embedding")
|
2938
|
+
converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
|
2939
|
+
converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
|
2940
|
+
converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
|
2941
|
+
converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
|
2942
|
+
converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
|
2943
|
+
|
2944
|
+
    for i in range(num_layers):
        converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(
            f"blocks.{i}.scale_shift_table"
        )

        # Self-Attention
        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
        converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q])
        converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k])
        converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])

        # Output Projections
        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(
            f"blocks.{i}.attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(
            f"blocks.{i}.attn.proj.bias"
        )

        # Cross-Attention
        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(
            f"blocks.{i}.cross_attn.q_linear.weight"
        )
        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(
            f"blocks.{i}.cross_attn.q_linear.bias"
        )

        linear_sample_k, linear_sample_v = torch.chunk(
            checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0
        )
        linear_sample_k_bias, linear_sample_v_bias = torch.chunk(
            checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0
        )
        converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
        converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
        converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
        converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias

        # Output Projections
        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
            f"blocks.{i}.cross_attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
            f"blocks.{i}.cross_attn.proj.bias"
        )

        # MLP
        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(
            f"blocks.{i}.mlp.inverted_conv.conv.weight"
        )
        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(
            f"blocks.{i}.mlp.inverted_conv.conv.bias"
        )
        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(
            f"blocks.{i}.mlp.depth_conv.conv.weight"
        )
        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(
            f"blocks.{i}.mlp.depth_conv.conv.bias"
        )
        converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(
            f"blocks.{i}.mlp.point_conv.conv.weight"
        )

    # Final layer
    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
    converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")

    return converted_state_dict

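# --- Editor's sketch (illustrative, not part of the diffusers diff) ----------
# The self-attention handling above splits a fused qkv projection with
# torch.chunk along dim=0 (the stacked output rows); torch.cat on a
# single-tensor list, as used for to_q/to_k/to_v, is simply a copy. A minimal,
# self-contained demonstration with assumed toy shapes:


def _demo_split_fused_qkv():
    import torch

    hidden = 8
    qkv_weight = torch.randn(3 * hidden, hidden)  # rows stacked as [q; k; v]
    q, k, v = torch.chunk(qkv_weight, 3, dim=0)
    assert q.shape == k.shape == v.shape == (hidden, hidden)
    # Re-fusing the chunks recovers the original tensor exactly.
    assert torch.equal(torch.cat([q, k, v], dim=0), qkv_weight)

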
def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}

    keys = list(checkpoint.keys())
    for k in keys:
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    TRANSFORMER_KEYS_RENAME_DICT = {
        "time_embedding.0": "condition_embedder.time_embedder.linear_1",
        "time_embedding.2": "condition_embedder.time_embedder.linear_2",
        "text_embedding.0": "condition_embedder.text_embedder.linear_1",
        "text_embedding.2": "condition_embedder.text_embedder.linear_2",
        "time_projection.1": "condition_embedder.time_proj",
        "cross_attn": "attn2",
        "self_attn": "attn1",
        ".o.": ".to_out.0.",
        ".q.": ".to_q.",
        ".k.": ".to_k.",
        ".v.": ".to_v.",
        ".k_img.": ".add_k_proj.",
        ".v_img.": ".add_v_proj.",
        ".norm_k_img.": ".norm_added_k.",
        "head.modulation": "scale_shift_table",
        "head.head": "proj_out",
        "modulation": "scale_shift_table",
        "ffn.0": "ffn.net.0.proj",
        "ffn.2": "ffn.net.2",
        # Hack to swap the layer names
        # The original model calls the norms in the following order: norm1, norm3, norm2
        # We convert it to: norm1, norm2, norm3
        "norm2": "norm__placeholder",
        "norm3": "norm2",
        "norm__placeholder": "norm3",
        # For the I2V model
        "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
        "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
        "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
        "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
    }

    for key in list(checkpoint.keys()):
        new_key = key[:]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)

        converted_state_dict[new_key] = checkpoint.pop(key)

    return converted_state_dict

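# --- Editor's sketch (illustrative, not part of the diffusers diff) ----------
# The substring-rename loop above is order-sensitive: dicts preserve insertion
# order, so "norm2" is first parked under a placeholder name, letting "norm3"
# take "norm2" before the placeholder finally becomes "norm3". A minimal
# demonstration of that three-step swap (key names here are hypothetical):


def _demo_norm_swap():
    rename = {
        "norm2": "norm__placeholder",  # park norm2 out of the way first
        "norm3": "norm2",
        "norm__placeholder": "norm3",
    }

    def apply(key):
        for old, new in rename.items():
            key = key.replace(old, new)
        return key

    assert apply("blocks.0.norm2.weight") == "blocks.0.norm3.weight"
    assert apply("blocks.0.norm3.weight") == "blocks.0.norm2.weight"

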
def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}

    # Create mappings for specific components
    middle_key_mapping = {
        # Encoder middle block
        "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
        "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
        "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
        "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
        "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
        "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
        "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
        "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
        "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
        "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
        "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
        "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
        # Decoder middle block
        "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
        "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
        "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
        "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
        "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
        "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
        "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
        "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
        "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
        "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
        "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
        "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
    }

    # Create a mapping for attention blocks
    attention_mapping = {
        # Encoder middle attention
        "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
        "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
        "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
        "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
        "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
        # Decoder middle attention
        "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
        "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
        "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
        "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
        "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
    }

    # Create a mapping for the head components
    head_mapping = {
        # Encoder head
        "encoder.head.0.gamma": "encoder.norm_out.gamma",
        "encoder.head.2.bias": "encoder.conv_out.bias",
        "encoder.head.2.weight": "encoder.conv_out.weight",
        # Decoder head
        "decoder.head.0.gamma": "decoder.norm_out.gamma",
        "decoder.head.2.bias": "decoder.conv_out.bias",
        "decoder.head.2.weight": "decoder.conv_out.weight",
    }

    # Create a mapping for the quant components
    quant_mapping = {
        "conv1.weight": "quant_conv.weight",
        "conv1.bias": "quant_conv.bias",
        "conv2.weight": "post_quant_conv.weight",
        "conv2.bias": "post_quant_conv.bias",
    }

    # Process each key in the state dict
    for key, value in checkpoint.items():
        # Handle middle block keys using the mapping
        if key in middle_key_mapping:
            new_key = middle_key_mapping[key]
            converted_state_dict[new_key] = value
        # Handle attention blocks using the mapping
        elif key in attention_mapping:
            new_key = attention_mapping[key]
            converted_state_dict[new_key] = value
        # Handle head keys using the mapping
        elif key in head_mapping:
            new_key = head_mapping[key]
            converted_state_dict[new_key] = value
        # Handle quant keys using the mapping
        elif key in quant_mapping:
            new_key = quant_mapping[key]
            converted_state_dict[new_key] = value
        # Handle encoder conv1
        elif key == "encoder.conv1.weight":
            converted_state_dict["encoder.conv_in.weight"] = value
        elif key == "encoder.conv1.bias":
            converted_state_dict["encoder.conv_in.bias"] = value
        # Handle decoder conv1
        elif key == "decoder.conv1.weight":
            converted_state_dict["decoder.conv_in.weight"] = value
        elif key == "decoder.conv1.bias":
            converted_state_dict["decoder.conv_in.bias"] = value
        # Handle encoder downsamples
        elif key.startswith("encoder.downsamples."):
            # Convert to down_blocks
            new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")

            # Convert residual block naming but keep the original structure
            if ".residual.0.gamma" in new_key:
                new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
            elif ".residual.2.bias" in new_key:
                new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
            elif ".residual.2.weight" in new_key:
                new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
            elif ".residual.3.gamma" in new_key:
                new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
            elif ".residual.6.bias" in new_key:
                new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
            elif ".residual.6.weight" in new_key:
                new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
            elif ".shortcut.bias" in new_key:
                new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
            elif ".shortcut.weight" in new_key:
                new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")

            converted_state_dict[new_key] = value

        # Handle decoder upsamples
        elif key.startswith("decoder.upsamples."):
            # Convert to up_blocks
            parts = key.split(".")
            block_idx = int(parts[2])

            # Group residual blocks
            if "residual" in key:
                if block_idx in [0, 1, 2]:
                    new_block_idx = 0
                    resnet_idx = block_idx
                elif block_idx in [4, 5, 6]:
                    new_block_idx = 1
                    resnet_idx = block_idx - 4
                elif block_idx in [8, 9, 10]:
                    new_block_idx = 2
                    resnet_idx = block_idx - 8
                elif block_idx in [12, 13, 14]:
                    new_block_idx = 3
                    resnet_idx = block_idx - 12
                else:
                    # Keep as is for other blocks
                    converted_state_dict[key] = value
                    continue

                # Convert residual block naming
                if ".residual.0.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
                elif ".residual.2.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
                elif ".residual.2.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
                elif ".residual.3.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
                elif ".residual.6.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
                elif ".residual.6.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
                else:
                    new_key = key

                converted_state_dict[new_key] = value

            # Handle shortcut connections
            elif ".shortcut." in key:
                if block_idx == 4:
                    new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
                    new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                    new_key = new_key.replace(".shortcut.", ".conv_shortcut.")

                converted_state_dict[new_key] = value

            # Handle upsamplers
            elif ".resample." in key or ".time_conv." in key:
                if block_idx == 3:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
                elif block_idx == 7:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
                elif block_idx == 11:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")

                converted_state_dict[new_key] = value
            else:
                new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                converted_state_dict[new_key] = value
        else:
            # Keep other keys unchanged
            converted_state_dict[key] = value

    return converted_state_dict
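
# --- Editor's sketch (illustrative, not part of the diffusers diff) ----------
# The if/elif ladder above encodes the decoder layout: "decoder.upsamples.<idx>"
# repeats [resnet, resnet, resnet, upsampler] with period 4, so indices 0-2 map
# to up_blocks.0.resnets.0-2, index 3 to up_blocks.0.upsamplers.0, and so on.
# A hypothetical divmod-based equivalent, shown only to make the index
# arithmetic explicit:


def _demo_upsample_index(block_idx):
    new_block_idx, pos = divmod(block_idx, 4)
    if pos == 3:  # every fourth entry is the block's upsampler
        return f"decoder.up_blocks.{new_block_idx}.upsamplers.0"
    return f"decoder.up_blocks.{new_block_idx}.resnets.{pos}"


# _demo_upsample_index(5)  -> "decoder.up_blocks.1.resnets.1"
# _demo_upsample_index(11) -> "decoder.up_blocks.2.upsamplers.0"
#
# A hypothetical end-to-end use of the converter, assuming a local Wan VAE
# checkpoint in .safetensors format (the file name is illustrative):
#
#   from safetensors.torch import load_file
#   state_dict = load_file("wan_vae.safetensors")
#   model.load_state_dict(convert_wan_vae_to_diffusers(state_dict))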