diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +198 -28
- diffusers/loaders/lora_conversion_utils.py +679 -44
- diffusers/loaders/lora_pipeline.py +1963 -801
- diffusers/loaders/peft.py +169 -84
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +653 -75
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +22 -32
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +10 -2
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +14 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.1.dist-info/RECORD +0 -550
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
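The headline additions of 0.33.0 are visible in the listing: a new diffusers/hooks package (FasterCache, group offloading, hook infrastructure, layerwise casting, Pyramid Attention Broadcast), diffusers/models/auto_model.py, a Quanto quantizer backend, remote-decoding utilities (diffusers/utils/remote_utils.py), and several new model families (Wan, Lumina2, OmniGen, EasyAnimate, ConsisID, CogView4, SANA Sprint). As a taste of the hooks package, here is a sketch of block-level group offloading; the `enable_group_offload` name and arguments are assumptions inferred from `diffusers/hooks/group_offloading.py` above, not verified against the released API:

import torch
from diffusers import HunyuanVideoPipeline

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
)

# Group offloading keeps only small groups of transformer blocks on the GPU,
# moving each group back to CPU memory once it has run.
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
)

The rest of this page excerpts the diff of a single file from the listing, diffusers/models/transformers/transformer_hunyuan_video.py (+409 -49), which adds token-replace image conditioning and cache support to the HunyuanVideo transformer.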
diffusers/models/transformers/transformer_hunyuan_video.py
@@ -22,17 +22,20 @@ from diffusers.loaders import FromOriginalModelMixin
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import FeedForward
 from ..attention_processor import Attention, AttentionProcessor
+from ..cache_utils import CacheMixin
 from ..embeddings import (
-    CombinedTimestepGuidanceTextProjEmbeddings,
     CombinedTimestepTextProjEmbeddings,
+    PixArtAlphaTextProjection,
+    TimestepEmbedding,
+    Timesteps,
     get_1d_rotary_pos_embed,
 )
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
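The import changes track the rest of the diff: `is_torch_version` leaves with the old gradient-checkpointing shim (see the forward-pass hunk at the end), `CombinedTimestepGuidanceTextProjEmbeddings` is superseded by the new `HunyuanVideoConditionEmbedding`, and `FP32LayerNorm` comes in for the token-replace norms. `FP32LayerNorm` normalizes in float32 and casts back, which matters when the model itself runs in bf16/fp16; a minimal sketch of that idea (not the library's exact implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FP32LayerNormSketch(nn.LayerNorm):
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        # Compute the normalization in float32 for numerical stability,
        # then return to the original (possibly half-precision) dtype.
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        return F.layer_norm(
            inputs.float(), self.normalized_shape, weight, bias, self.eps
        ).to(origin_dtype)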
@@ -172,6 +175,141 @@ class HunyuanVideoAdaNorm(nn.Module):
         return gate_msa, gate_mlp
 
 
+class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        elif norm_type == "fp32_layer_norm":
+            self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        emb: torch.Tensor,
+        token_replace_emb: torch.Tensor,
+        first_frame_num_tokens: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        emb = self.linear(self.silu(emb))
+        token_replace_emb = self.linear(self.silu(token_replace_emb))
+
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
+        tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk(
+            6, dim=1
+        )
+
+        norm_hidden_states = self.norm(hidden_states)
+        hidden_states_zero = (
+            norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
+        )
+        hidden_states_orig = (
+            norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        )
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+
+        return (
+            hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            tr_gate_msa,
+            tr_shift_mlp,
+            tr_scale_mlp,
+            tr_gate_mlp,
+        )
+
+
+class HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(nn.Module):
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        emb: torch.Tensor,
+        token_replace_emb: torch.Tensor,
+        first_frame_num_tokens: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        emb = self.linear(self.silu(emb))
+        token_replace_emb = self.linear(self.silu(token_replace_emb))
+
+        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
+        tr_shift_msa, tr_scale_msa, tr_gate_msa = token_replace_emb.chunk(3, dim=1)
+
+        norm_hidden_states = self.norm(hidden_states)
+        hidden_states_zero = (
+            norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
+        )
+        hidden_states_orig = (
+            norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        )
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+
+        return hidden_states, gate_msa, tr_gate_msa
+
+
+class HunyuanVideoConditionEmbedding(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        pooled_projection_dim: int,
+        guidance_embeds: bool,
+        image_condition_type: Optional[str] = None,
+    ):
+        super().__init__()
+
+        self.image_condition_type = image_condition_type
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+        self.guidance_embedder = None
+        if guidance_embeds:
+            self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+    def forward(
+        self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+        pooled_projections = self.text_embedder(pooled_projection)
+        conditioning = timesteps_emb + pooled_projections
+
+        token_replace_emb = None
+        if self.image_condition_type == "token_replace":
+            token_replace_timestep = torch.zeros_like(timestep)
+            token_replace_proj = self.time_proj(token_replace_timestep)
+            token_replace_emb = self.timestep_embedder(token_replace_proj.to(dtype=pooled_projection.dtype))
+            token_replace_emb = token_replace_emb + pooled_projections
+
+        if self.guidance_embedder is not None:
+            guidance_proj = self.time_proj(guidance)
+            guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
+            conditioning = conditioning + guidance_emb
+
+        return conditioning, token_replace_emb
+
+
 class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
     def __init__(
         self,
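All three new modules serve the token-replace image-to-video scheme: `HunyuanVideoConditionEmbedding` additionally returns `token_replace_emb`, a conditioning vector computed at timestep 0 (the conditioning frame is treated as already denoised), and the two token-replace AdaLN variants apply that set of shift/scale/gate parameters to the first-frame tokens while the remaining tokens get the ordinary timestep conditioning. A toy sketch of the split modulation, with made-up shapes:

import torch

B, T, D = 1, 10, 8
first_frame_num_tokens = 4
norm_hidden_states = torch.randn(B, T, D)

# Per-sample modulation parameters: one set from the real timestep embedding,
# one set from the timestep-0 ("token replace") embedding.
scale, shift = torch.randn(B, D), torch.randn(B, D)
tr_scale, tr_shift = torch.randn(B, D), torch.randn(B, D)

# First-frame tokens are modulated as if already denoised (timestep 0)...
zero = norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale[:, None]) + tr_shift[:, None]
# ...while the remaining tokens get the ordinary conditioning.
orig = norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale[:, None]) + shift[:, None]
out = torch.cat([zero, orig], dim=1)
assert out.shape == (B, T, D)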
@@ -389,6 +527,8 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
         temb: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -467,6 +607,8 @@ class HunyuanVideoTransformerBlock(nn.Module):
         temb: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        *args,
+        **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Input normalization
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
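The `*args, **kwargs` added to both base blocks exist so the forward pass (last hunk) can call every block, token-replace or not, with one uniform positional argument list; the base variants simply ignore the trailing `token_replace_emb` and `first_frame_num_tokens`. Schematically, with hypothetical class names:

import torch
import torch.nn as nn

class PlainBlock(nn.Module):
    # Accepts and ignores extra conditioning arguments.
    def forward(self, hidden_states, temb, *args, **kwargs):
        return hidden_states + temb

class TokenReplaceBlock(nn.Module):
    # Actually consumes the extra arguments.
    def forward(self, hidden_states, temb, token_replace_emb=None, num_tokens=None):
        return hidden_states + temb + (token_replace_emb if token_replace_emb is not None else 0)

# Both variants tolerate the same call signature.
for block in (PlainBlock(), TokenReplaceBlock()):
    out = block(torch.zeros(1, 4), torch.ones(1, 4), torch.ones(1, 4), 2)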
@@ -502,7 +644,182 @@ class HunyuanVideoTransformerBlock(nn.Module):
         return hidden_states, encoder_hidden_states
 
 
-class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float = 4.0,
+        qk_norm: str = "rms_norm",
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+        mlp_dim = int(hidden_size * mlp_ratio)
+
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=hidden_size,
+            bias=True,
+            processor=HunyuanVideoAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=1e-6,
+            pre_only=True,
+        )
+
+        self.norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
+        self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
+        self.act_mlp = nn.GELU(approximate="tanh")
+        self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        token_replace_emb: torch.Tensor = None,
+        num_tokens: int = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.shape[1]
+        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+        residual = hidden_states
+
+        # 1. Input normalization
+        norm_hidden_states, gate, tr_gate = self.norm(hidden_states, temb, token_replace_emb, num_tokens)
+        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+        norm_hidden_states, norm_encoder_hidden_states = (
+            norm_hidden_states[:, :-text_seq_length, :],
+            norm_hidden_states[:, -text_seq_length:, :],
+        )
+
+        # 2. Attention
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            attention_mask=attention_mask,
+            image_rotary_emb=image_rotary_emb,
+        )
+        attn_output = torch.cat([attn_output, context_attn_output], dim=1)
+
+        # 3. Modulation and residual connection
+        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+
+        proj_output = self.proj_out(hidden_states)
+        hidden_states_zero = proj_output[:, :num_tokens] * tr_gate.unsqueeze(1)
+        hidden_states_orig = proj_output[:, num_tokens:] * gate.unsqueeze(1)
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        hidden_states = hidden_states + residual
+
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, :-text_seq_length, :],
+            hidden_states[:, -text_seq_length:, :],
+        )
+        return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float,
+        qk_norm: str = "rms_norm",
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+
+        self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm")
+        self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            added_kv_proj_dim=hidden_size,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=hidden_size,
+            context_pre_only=False,
+            bias=True,
+            processor=HunyuanVideoAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=1e-6,
+        )
+
+        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+        self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        token_replace_emb: torch.Tensor = None,
+        num_tokens: int = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # 1. Input normalization
+        (
+            norm_hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            tr_gate_msa,
+            tr_shift_mlp,
+            tr_scale_mlp,
+            tr_gate_mlp,
+        ) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens)
+        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+            encoder_hidden_states, emb=temb
+        )
+
+        # 2. Joint attention
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            attention_mask=attention_mask,
+            image_rotary_emb=freqs_cis,
+        )
+
+        # 3. Modulation and residual connection
+        hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1)
+        hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1)
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+        hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None]
+        hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        norm_hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+        # 4. Feed-forward
+        ff_output = self.ff(norm_hidden_states)
+        context_ff_output = self.ff_context(norm_encoder_hidden_states)
+
+        hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1)
+        hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1)
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+        return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
     r"""
     A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
 
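Besides the two token-replace block variants (structurally identical to `HunyuanVideoSingleTransformerBlock` / `HunyuanVideoTransformerBlock` apart from the split gating shown above), the model now inherits `CacheMixin` from the new `diffusers/models/cache_utils.py`, which is where `enable_cache()` hangs off. A hedged usage sketch with Pyramid Attention Broadcast; the config fields, the `current_timestep` callback, and the `hunyuanvideo-community` checkpoint name are assumptions, not taken from this diff:

import torch
from diffusers import HunyuanVideoPipeline
from diffusers.hooks import PyramidAttentionBroadcastConfig

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
).to("cuda")

# CacheMixin adds enable_cache()/disable_cache() to the transformer; PAB reuses
# attention outputs across nearby denoising steps instead of recomputing them.
pipe.transformer.enable_cache(
    PyramidAttentionBroadcastConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(100, 800),
        current_timestep_callback=lambda: pipe.current_timestep,
    )
)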
@@ -539,9 +856,20 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
             The value of theta to use in the RoPE layer.
         rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
             The dimensions of the axes to use in the RoPE layer.
+        image_condition_type (`str`, *optional*, defaults to `None`):
+            The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
+            image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
+            tokens in the latent stream and apply conditioning.
     """
 
     _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
+    _no_split_modules = [
+        "HunyuanVideoTransformerBlock",
+        "HunyuanVideoSingleTransformerBlock",
+        "HunyuanVideoPatchEmbed",
+        "HunyuanVideoTokenRefiner",
+    ]
 
     @register_to_config
     def __init__(
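The two new class attributes wire the model into the new hooks and into sharded loading: `_skip_layerwise_casting_patterns` keeps the embedders and normalization layers out of low-precision storage when layerwise casting is enabled (their parameters are sensitive to fp8 rounding), and `_no_split_modules` tells `device_map` planning which submodules must stay on a single device. A hedged sketch of layerwise casting, assuming the `enable_layerwise_casting` entry point backed by `diffusers/hooks/layerwise_casting.py`:

import torch
from diffusers import HunyuanVideoTransformer3DModel

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Store weights in fp8 and upcast each layer to bf16 just-in-time for compute.
# Modules matching _skip_layerwise_casting_patterns keep their original dtype.
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)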
@@ -562,9 +890,16 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         pooled_projection_dim: int = 768,
         rope_theta: float = 256.0,
         rope_axes_dim: Tuple[int] = (16, 56, 56),
+        image_condition_type: Optional[str] = None,
     ) -> None:
         super().__init__()
 
+        supported_image_condition_types = ["latent_concat", "token_replace"]
+        if image_condition_type is not None and image_condition_type not in supported_image_condition_types:
+            raise ValueError(
+                f"Invalid `image_condition_type` ({image_condition_type}). Supported ones are: {supported_image_condition_types}"
+            )
+
         inner_dim = num_attention_heads * attention_head_dim
         out_channels = out_channels or in_channels
 
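`image_condition_type` distinguishes the two image-to-video variants shipped in this release (compare `pipeline_hunyuan_skyreels_image2video.py`, which concatenates the image latents, with `pipeline_hunyuan_video_image2video.py`, which replaces first-frame tokens). A small sketch of the constructor behaviour, using deliberately tiny, made-up dimensions:

from diffusers import HunyuanVideoTransformer3DModel

tiny = dict(
    in_channels=4, out_channels=4, num_attention_heads=2, attention_head_dim=8,
    num_layers=1, num_single_layers=1, num_refiner_layers=1, text_embed_dim=8,
    pooled_projection_dim=8, rope_axes_dim=(2, 2, 4),
)
# Selects the HunyuanVideoTokenReplace* blocks built in the next hunk.
model = HunyuanVideoTransformer3DModel(**tiny, image_condition_type="token_replace")

try:
    HunyuanVideoTransformer3DModel(**tiny, image_condition_type="pixel_concat")
except ValueError as err:
    print(err)  # Invalid `image_condition_type` (pixel_concat). Supported ones are: [...]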
@@ -573,30 +908,53 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         self.context_embedder = HunyuanVideoTokenRefiner(
             text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
         )
-        self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
+
+        self.time_text_embed = HunyuanVideoConditionEmbedding(
+            inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
+        )
 
         # 2. RoPE
         self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
 
         # 3. Dual stream transformer blocks
-        self.transformer_blocks = nn.ModuleList(
-            [
-                HunyuanVideoTransformerBlock(
-                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
-                )
-                for _ in range(num_layers)
-            ]
-        )
+        if image_condition_type == "token_replace":
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoTokenReplaceTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
+        else:
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
 
         # 4. Single stream transformer blocks
-        self.single_transformer_blocks = nn.ModuleList(
-            [
-                HunyuanVideoSingleTransformerBlock(
-                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
-                )
-                for _ in range(num_single_layers)
-            ]
-        )
+        if image_condition_type == "token_replace":
+            self.single_transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoTokenReplaceSingleTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_single_layers)
+                ]
+            )
+        else:
+            self.single_transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoSingleTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_single_layers)
+                ]
+            )
 
         # 5. Output projection
         self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
@@ -664,10 +1022,6 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -699,12 +1053,14 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         post_patch_num_frames = num_frames // p_t
         post_patch_height = height // p
         post_patch_width = width // p
+        first_frame_num_tokens = 1 * post_patch_height * post_patch_width
 
         # 1. RoPE
         image_rotary_emb = self.rope(hidden_states)
 
         # 2. Conditional embeddings
-        temb = self.time_text_embed(timestep, guidance, pooled_projections)
+        temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance)
+
         hidden_states = self.x_embedder(hidden_states)
         encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
 
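Note the call site also switches to the new `HunyuanVideoConditionEmbedding` signature (`timestep, pooled_projection, guidance`) and now receives the extra `token_replace_emb`. `first_frame_num_tokens` counts the patch tokens contributed by the first latent frame, which is exactly the span the token-replace blocks treat specially. Worked numbers for a typical 720x1280, 61-frame generation (assuming the usual 8x spatial / 4x temporal VAE compression):

# Latent grid after the VAE: num_frames=16, height=90, width=160.
p, p_t = 2, 1  # patch_size, patch_size_t defaults
num_frames, height, width = 16, 90, 160

post_patch_num_frames = num_frames // p_t  # 16
post_patch_height = height // p            # 45
post_patch_width = width // p              # 80
first_frame_num_tokens = 1 * post_patch_height * post_patch_width
print(first_frame_num_tokens)              # 3600 of the 57600 video tokens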
@@ -713,60 +1069,64 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         condition_sequence_length = encoder_hidden_states.shape[1]
         sequence_length = latent_sequence_length + condition_sequence_length
         attention_mask = torch.zeros(
-            batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool
-        )  # [B, N, N]
+            batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
+        )  # [B, N]
 
         effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
         effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
 
         for i in range(batch_size):
-            attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True
+            attention_mask[i, : effective_sequence_length[i]] = True
+        # [B, 1, 1, N], for broadcasting across attention heads
+        attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
 
         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-
-                return custom_forward
-
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-
             for block in self.transformer_blocks:
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     temb,
                     attention_mask,
                     image_rotary_emb,
-                    **ckpt_kwargs,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                 )
 
             for block in self.single_transformer_blocks:
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     temb,
                     attention_mask,
                     image_rotary_emb,
-                    **ckpt_kwargs,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                 )
 
         else:
             for block in self.transformer_blocks:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    attention_mask,
+                    image_rotary_emb,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                 )
 
             for block in self.single_transformer_blocks:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    attention_mask,
+                    image_rotary_emb,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                 )
 
         # 5. Output projection
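Two independent changes land in this final hunk. First, gradient checkpointing is now routed through `self._gradient_checkpointing_func`, which the `ModelMixin` changes in this release appear to provide, replacing the per-model `create_custom_forward` / `is_torch_version` boilerplate removed here and the `_set_gradient_checkpointing` hook removed above. Second, the attention mask shrinks from a dense per-sample `[B, N, N]` matrix to a `[B, N]` key mask reshaped to `[B, 1, 1, N]` and broadcast across heads and query positions; because the invalid tokens are always a suffix, both forms admit the same keys for every valid query row, at O(N) instead of O(N^2) memory. A quick equivalence check on toy shapes:

import torch

B, N = 2, 6
effective_len = torch.tensor([4, 6])

# Old form: full [B, N, N] mask.
old_mask = torch.zeros(B, N, N, dtype=torch.bool)
for i in range(B):
    old_mask[i, : effective_len[i], : effective_len[i]] = True

# New form: [B, N] -> [B, 1, 1, N] broadcast mask.
new_mask = torch.zeros(B, N, dtype=torch.bool)
for i in range(B):
    new_mask[i, : effective_len[i]] = True
new_mask = new_mask.unsqueeze(1).unsqueeze(1)

# For every valid (non-padded) query row the allowed key sets agree.
for i in range(B):
    valid = int(effective_len[i])
    assert torch.equal(old_mask[i, :valid], new_mask[i, 0, 0].expand(valid, N))
print("broadcast mask matches the dense mask on all valid rows")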