diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +198 -28
- diffusers/loaders/lora_conversion_utils.py +679 -44
- diffusers/loaders/lora_pipeline.py +1963 -801
- diffusers/loaders/peft.py +169 -84
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +653 -75
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +22 -32
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +10 -2
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +14 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.1.dist-info/RECORD +0 -550
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
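The largest structural addition in 0.33.0 is the new diffusers/hooks package (hooks.py, group_offloading.py, layerwise_casting.py, faster_cache.py, pyramid_attention_broadcast.py above). For orientation, a minimal sketch of the group-offloading entry point that hooks/group_offloading.py backs, assuming the enable_group_offload method and its block_level offload type as documented for this release; the model choice is illustrative:

import torch
from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Keep weights in CPU memory and stream groups of blocks onto the GPU only while they run.
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
)

Block-level offloading trades some latency for a large cut in peak VRAM, since only num_blocks_per_group transformer blocks are resident on the accelerator at a time.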
diffusers/models/normalization.py CHANGED
@@ -20,7 +20,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ..utils import is_torch_version
+from ..utils import is_torch_npu_available, is_torch_version
 from .activations import get_activation
 from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
 
@@ -71,7 +71,7 @@ class AdaLayerNorm(nn.Module):
 
         if self.chunk_dim == 1:
             # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
-            # other if-branch. This branch is specific to CogVideoX for now.
+            # other if-branch. This branch is specific to CogVideoX and OmniGen for now.
             shift, scale = temb.chunk(2, dim=1)
             shift = shift[:, None, :]
             scale = scale[:, None, :]
@@ -219,14 +219,13 @@ class LuminaRMSNormZero(nn.Module):
             4 * embedding_dim,
             bias=True,
         )
-        self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.norm = RMSNorm(embedding_dim, eps=norm_eps)
 
     def forward(
         self,
         x: torch.Tensor,
         emb: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        # emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
         emb = self.linear(self.silu(emb))
         scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
         x = self.norm(x) * (1 + scale_msa[:, None])
@@ -307,6 +306,20 @@ class AdaGroupNorm(nn.Module):
 
 
 class AdaLayerNormContinuous(nn.Module):
+    r"""
+    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
+
+    Args:
+        embedding_dim (`int`): Embedding dimension to use during projection.
+        conditioning_embedding_dim (`int`): Dimension of the input condition.
+        elementwise_affine (`bool`, defaults to `True`):
+            Boolean flag to denote if affine transformation should be applied.
+        eps (`float`, defaults to 1e-5): Epsilon factor.
+        bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
+        norm_type (`str`, defaults to `"layer_norm"`):
+            Normalization layer to use. Values supported: "layer_norm", "rms_norm".
+    """
+
     def __init__(
         self,
         embedding_dim: int,
@@ -463,6 +476,17 @@ else:
     # Has optional bias parameter compared to torch layer norm
     # TODO: replace with torch layernorm once min required torch version >= 2.1
     class LayerNorm(nn.Module):
+        r"""
+        LayerNorm with the bias parameter.
+
+        Args:
+            dim (`int`): Dimensionality to use for the parameters.
+            eps (`float`, defaults to 1e-5): Epsilon factor.
+            elementwise_affine (`bool`, defaults to `True`):
+                Boolean flag to denote if affine transformation should be applied.
+            bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
+        """
+
         def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
             super().__init__()
 
@@ -485,6 +509,17 @@ else:
 
 
 class RMSNorm(nn.Module):
+    r"""
+    RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al.
+
+    Args:
+        dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
+        eps (`float`): Small value to use when calculating the reciprocal of the square-root.
+        elementwise_affine (`bool`, defaults to `True`):
+            Boolean flag to denote if affine transformation should be applied.
+        bias (`bool`, defaults to False): If also training the `bias` param.
+    """
+
     def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
         super().__init__()
 
@@ -505,19 +540,30 @@ class RMSNorm(nn.Module):
                 self.bias = nn.Parameter(torch.zeros(dim))
 
     def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-
-        if self.weight is not None:
-            # convert into half-precision if necessary
-            if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                hidden_states = hidden_states.to(self.weight.dtype)
-            hidden_states = hidden_states * self.weight
+        if is_torch_npu_available():
+            import torch_npu
+
+            if self.weight is not None:
+                # convert into half-precision if necessary
+                if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                    hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
             if self.bias is not None:
                 hidden_states = hidden_states + self.bias
         else:
-            hidden_states = hidden_states.to(input_dtype)
+            input_dtype = hidden_states.dtype
+            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+            if self.weight is not None:
+                # convert into half-precision if necessary
+                if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                    hidden_states = hidden_states.to(self.weight.dtype)
+                hidden_states = hidden_states * self.weight
+                if self.bias is not None:
+                    hidden_states = hidden_states + self.bias
+            else:
+                hidden_states = hidden_states.to(input_dtype)
 
         return hidden_states
 
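For reference, the new else branch above is exactly the RMS normalization that 0.32.1 ran unconditionally; only the torch_npu fast path is new. A standalone sketch of that fallback computation (plain PyTorch, independent of this package):

from typing import Optional

import torch

def rms_norm(hidden_states: torch.Tensor, weight: Optional[torch.Tensor] = None, eps: float = 1e-6) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    # Mean of squares in float32 for numerical stability, then scale by the reciprocal square root.
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    if weight is not None:
        # Match the parameter dtype before the affine scale, as the diff does for fp16/bf16 weights.
        hidden_states = hidden_states.to(weight.dtype) * weight
    else:
        hidden_states = hidden_states.to(input_dtype)
    return hidden_states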
@@ -553,6 +599,13 @@ class MochiRMSNorm(nn.Module):
 
 
 class GlobalResponseNorm(nn.Module):
+    r"""
+    Global response normalization as introduced in ConvNeXt-v2 (https://arxiv.org/abs/2301.00808).
+
+    Args:
+        dim (`int`): Number of dimensions to use for the `gamma` and `beta`.
+    """
+
     # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
     def __init__(self, dim):
         super().__init__()
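The new docstring's ConvNeXt-V2 reference pins down what GlobalResponseNorm computes: a global L2 aggregation over the spatial dimensions, divisive normalization across channels, and a learned affine plus residual. A sketch following the linked facebookresearch/ConvNeXt-V2 reference (channels-last input of shape (N, H, W, C); this restates the paper's formulation rather than code copied from the wheel):

import torch
import torch.nn as nn

class GRN(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)    # global feature aggregation per channel
        nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)     # divisive normalization across channels
        return self.gamma * (x * nx) + self.beta + x         # learned affine plus residual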
diffusers/models/resnet.py CHANGED
@@ -366,7 +366,7 @@ class ResnetBlock2D(nn.Module):
         hidden_states = self.conv2(hidden_states)
 
         if self.conv_shortcut is not None:
-            input_tensor = self.conv_shortcut(input_tensor)
+            input_tensor = self.conv_shortcut(input_tensor.contiguous())
 
         output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
 
diffusers/models/transformers/__init__.py CHANGED
@@ -4,6 +4,7 @@ from ...utils import is_torch_available
 if is_torch_available():
     from .auraflow_transformer_2d import AuraFlowTransformer2DModel
     from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
+    from .consisid_transformer_3d import ConsisIDTransformer3DModel
     from .dit_transformer_2d import DiTTransformer2DModel
     from .dual_transformer_2d import DualTransformer2DModel
     from .hunyuan_transformer_2d import HunyuanDiT2DModel
@@ -17,9 +18,14 @@ if is_torch_available():
     from .transformer_2d import Transformer2DModel
     from .transformer_allegro import AllegroTransformer3DModel
     from .transformer_cogview3plus import CogView3PlusTransformer2DModel
+    from .transformer_cogview4 import CogView4Transformer2DModel
+    from .transformer_easyanimate import EasyAnimateTransformer3DModel
     from .transformer_flux import FluxTransformer2DModel
     from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
     from .transformer_ltx import LTXVideoTransformer3DModel
+    from .transformer_lumina2 import Lumina2Transformer2DModel
     from .transformer_mochi import MochiTransformer3DModel
+    from .transformer_omnigen import OmniGenTransformer2DModel
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_temporal import TransformerTemporalModel
+    from .transformer_wan import WanTransformer3DModel
diffusers/models/transformers/auraflow_transformer_2d.py CHANGED
@@ -13,14 +13,15 @@
 # limitations under the License.
 
 
-from typing import Any, Dict, Optional, Union
+from typing import Dict, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...loaders import FromOriginalModelMixin
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention_processor import (
     Attention,
@@ -253,7 +254,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
         return encoder_hidden_states, hidden_states
 
 
-class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
+class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
 
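Adding FromOriginalModelMixin to the bases exposes the single-file loader on the AuraFlow transformer. A hedged usage sketch; the checkpoint path is a placeholder, not something taken from this diff:

import torch
from diffusers import AuraFlowTransformer2DModel

# from_single_file is provided by FromOriginalModelMixin; the path below is hypothetical.
transformer = AuraFlowTransformer2DModel.from_single_file(
    "./checkpoints/aura_flow.safetensors",
    torch_dtype=torch.bfloat16,
)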
@@ -275,6 +276,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
     """
 
     _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
+    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
     _supports_gradient_checkpointing = True
 
     @register_to_config
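_skip_layerwise_casting_patterns pairs with the new diffusers/hooks/layerwise_casting.py listed above: submodules whose names match these patterns keep their original precision while the rest of the model is stored in a low-precision dtype. A sketch assuming the enable_layerwise_casting API this release adds on models:

import torch
from diffusers import AuraFlowTransformer2DModel

transformer = AuraFlowTransformer2DModel.from_pretrained(
    "fal/AuraFlow", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Store weights in float8 but upcast to bfloat16 for compute; modules matching
# _skip_layerwise_casting_patterns ("pos_embed", "norm") are left untouched.
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)

Excluding embeddings and norms from float8 storage is the conservative choice, since those layers are the most sensitive to quantization error.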
@@ -442,10 +444,6 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -467,23 +465,11 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         # MMDiT blocks.
         for index_block, block in enumerate(self.joint_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     temb,
-                    **ckpt_kwargs,
                 )
 
             else:
@@ -498,22 +484,10 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
 
         for index_block, block in enumerate(self.single_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                combined_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                combined_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     combined_hidden_states,
                     temb,
-                    **ckpt_kwargs,
                 )
 
             else:
diffusers/models/transformers/cogvideox_transformer_3d.py CHANGED
@@ -20,10 +20,11 @@ from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from ..cache_utils import CacheMixin
 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -120,8 +121,10 @@ class CogVideoXBlock(nn.Module):
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
+        attention_kwargs = attention_kwargs or {}
 
         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
@@ -133,6 +136,7 @@ class CogVideoXBlock(nn.Module):
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **attention_kwargs,
         )
 
         hidden_states = hidden_states + gate_msa * attn_hidden_states
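Threading attention_kwargs through CogVideoXBlock lets callers hand processor-level arguments, such as a LoRA scale consumed by the newly imported scale_lora_layers/unscale_lora_layers utilities, down from the model's forward. An end-to-end sketch, assuming the pipeline-level attention_kwargs argument this wiring implies; the LoRA file is a placeholder:

import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("./loras/my_cogvideox_lora.safetensors")  # hypothetical LoRA checkpoint

video = pipe(
    prompt="a panda playing guitar by a river",
    attention_kwargs={"scale": 0.8},  # scales LoRA layers via the path added in this diff
).frames[0]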
@@ -153,7 +157,7 @@ class CogVideoXBlock(nn.Module):
         return hidden_states, encoder_hidden_states
 
 
-class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
     """
     A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
 
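CacheMixin comes from the new diffusers/models/cache_utils.py and is the bridge between this transformer and the caching hooks added in this release (pyramid_attention_broadcast.py, faster_cache.py). A sketch of enabling pyramid attention broadcast, assuming the enable_cache/PyramidAttentionBroadcastConfig API that 0.33.0 exports:

import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

# Reuse attention outputs across nearby denoising steps instead of recomputing them every step.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)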
@@ -209,7 +213,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             Scaling factor to apply in 3D positional embeddings across temporal dimensions.
     """
 
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]
 
     @register_to_config
     def __init__(
@@ -325,9 +331,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
 
         self.gradient_checkpointing = False
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        self.gradient_checkpointing = value
-
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -483,21 +486,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     emb,
                     image_rotary_emb,
-                    **ckpt_kwargs,
+                    attention_kwargs,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
@@ -505,16 +500,10 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=emb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
                 )
 
-        if not self.config.use_rotary_positional_embeddings:
-            # CogVideoX-2B
-            hidden_states = self.norm_final(hidden_states)
-        else:
-            # CogVideoX-5B
-            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-            hidden_states = self.norm_final(hidden_states)
-            hidden_states = hidden_states[:, text_seq_length:]
+        hidden_states = self.norm_final(hidden_states)
 
         # 4. Final block
         hidden_states = self.norm_out(hidden_states, temb=emb)