diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +595 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
- diffusers-0.33.1.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
@@ -138,10 +138,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
|
|
138
138
|
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
139
139
|
self.tile_overlap_factor = 0.25
|
140
140
|
|
141
|
-
def _set_gradient_checkpointing(self, module, value=False):
|
142
|
-
if isinstance(module, (Encoder, Decoder)):
|
143
|
-
module.gradient_checkpointing = value
|
144
|
-
|
145
141
|
def enable_tiling(self, use_tiling: bool = True):
|
146
142
|
r"""
|
147
143
|
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
@@ -103,7 +103,7 @@ class AllegroTemporalConvLayer(nn.Module):
|
|
103
103
|
if self.down_sample:
|
104
104
|
identity = hidden_states[:, :, ::2]
|
105
105
|
elif self.up_sample:
|
106
|
-
identity = hidden_states.repeat_interleave(2, dim=2)
|
106
|
+
identity = hidden_states.repeat_interleave(2, dim=2, output_size=hidden_states.shape[2] * 2)
|
107
107
|
else:
|
108
108
|
identity = hidden_states
|
109
109
|
|
@@ -507,19 +507,12 @@ class AllegroEncoder3D(nn.Module):
|
|
507
507
|
sample = sample + residual
|
508
508
|
|
509
509
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
510
|
-
|
511
|
-
def create_custom_forward(module):
|
512
|
-
def custom_forward(*inputs):
|
513
|
-
return module(*inputs)
|
514
|
-
|
515
|
-
return custom_forward
|
516
|
-
|
517
510
|
# Down blocks
|
518
511
|
for down_block in self.down_blocks:
|
519
|
-
sample =
|
512
|
+
sample = self._gradient_checkpointing_func(down_block, sample)
|
520
513
|
|
521
514
|
# Mid block
|
522
|
-
sample =
|
515
|
+
sample = self._gradient_checkpointing_func(self.mid_block, sample)
|
523
516
|
else:
|
524
517
|
# Down blocks
|
525
518
|
for down_block in self.down_blocks:
|
@@ -647,19 +640,12 @@ class AllegroDecoder3D(nn.Module):
|
|
647
640
|
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
648
641
|
|
649
642
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
650
|
-
|
651
|
-
def create_custom_forward(module):
|
652
|
-
def custom_forward(*inputs):
|
653
|
-
return module(*inputs)
|
654
|
-
|
655
|
-
return custom_forward
|
656
|
-
|
657
643
|
# Mid block
|
658
|
-
sample =
|
644
|
+
sample = self._gradient_checkpointing_func(self.mid_block, sample)
|
659
645
|
|
660
646
|
# Up blocks
|
661
647
|
for up_block in self.up_blocks:
|
662
|
-
sample =
|
648
|
+
sample = self._gradient_checkpointing_func(up_block, sample)
|
663
649
|
|
664
650
|
else:
|
665
651
|
# Mid block
|
@@ -809,10 +795,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
|
|
809
795
|
sample_size - self.tile_overlap_w,
|
810
796
|
)
|
811
797
|
|
812
|
-
def _set_gradient_checkpointing(self, module, value=False):
|
813
|
-
if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
|
814
|
-
module.gradient_checkpointing = value
|
815
|
-
|
816
798
|
def enable_tiling(self) -> None:
|
817
799
|
r"""
|
818
800
|
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
@@ -105,6 +105,7 @@ class CogVideoXCausalConv3d(nn.Module):
|
|
105
105
|
self.width_pad = width_pad
|
106
106
|
self.time_pad = time_pad
|
107
107
|
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
|
108
|
+
self.const_padding_conv3d = (0, self.width_pad, self.height_pad)
|
108
109
|
|
109
110
|
self.temporal_dim = 2
|
110
111
|
self.time_kernel_size = time_kernel_size
|
@@ -117,6 +118,8 @@ class CogVideoXCausalConv3d(nn.Module):
|
|
117
118
|
kernel_size=kernel_size,
|
118
119
|
stride=stride,
|
119
120
|
dilation=dilation,
|
121
|
+
padding=0 if self.pad_mode == "replicate" else self.const_padding_conv3d,
|
122
|
+
padding_mode="zeros",
|
120
123
|
)
|
121
124
|
|
122
125
|
def fake_context_parallel_forward(
|
@@ -137,9 +140,7 @@ class CogVideoXCausalConv3d(nn.Module):
|
|
137
140
|
if self.pad_mode == "replicate":
|
138
141
|
conv_cache = None
|
139
142
|
else:
|
140
|
-
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
141
143
|
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
142
|
-
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
|
143
144
|
|
144
145
|
output = self.conv(inputs)
|
145
146
|
return output, conv_cache
|
@@ -421,15 +422,8 @@ class CogVideoXDownBlock3D(nn.Module):
|
|
421
422
|
conv_cache_key = f"resnet_{i}"
|
422
423
|
|
423
424
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
424
|
-
|
425
|
-
|
426
|
-
def create_forward(*inputs):
|
427
|
-
return module(*inputs)
|
428
|
-
|
429
|
-
return create_forward
|
430
|
-
|
431
|
-
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
432
|
-
create_custom_forward(resnet),
|
425
|
+
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
426
|
+
resnet,
|
433
427
|
hidden_states,
|
434
428
|
temb,
|
435
429
|
zq,
|
@@ -523,15 +517,8 @@ class CogVideoXMidBlock3D(nn.Module):
|
|
523
517
|
conv_cache_key = f"resnet_{i}"
|
524
518
|
|
525
519
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
526
|
-
|
527
|
-
|
528
|
-
def create_forward(*inputs):
|
529
|
-
return module(*inputs)
|
530
|
-
|
531
|
-
return create_forward
|
532
|
-
|
533
|
-
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
534
|
-
create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
|
520
|
+
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
521
|
+
resnet, hidden_states, temb, zq, conv_cache.get(conv_cache_key)
|
535
522
|
)
|
536
523
|
else:
|
537
524
|
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
@@ -637,15 +624,8 @@ class CogVideoXUpBlock3D(nn.Module):
|
|
637
624
|
conv_cache_key = f"resnet_{i}"
|
638
625
|
|
639
626
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
640
|
-
|
641
|
-
|
642
|
-
def create_forward(*inputs):
|
643
|
-
return module(*inputs)
|
644
|
-
|
645
|
-
return create_forward
|
646
|
-
|
647
|
-
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
648
|
-
create_custom_forward(resnet),
|
627
|
+
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
628
|
+
resnet,
|
649
629
|
hidden_states,
|
650
630
|
temb,
|
651
631
|
zq,
|
@@ -774,18 +754,11 @@ class CogVideoXEncoder3D(nn.Module):
|
|
774
754
|
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
775
755
|
|
776
756
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
777
|
-
|
778
|
-
def create_custom_forward(module):
|
779
|
-
def custom_forward(*inputs):
|
780
|
-
return module(*inputs)
|
781
|
-
|
782
|
-
return custom_forward
|
783
|
-
|
784
757
|
# 1. Down
|
785
758
|
for i, down_block in enumerate(self.down_blocks):
|
786
759
|
conv_cache_key = f"down_block_{i}"
|
787
|
-
hidden_states, new_conv_cache[conv_cache_key] =
|
788
|
-
|
760
|
+
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
761
|
+
down_block,
|
789
762
|
hidden_states,
|
790
763
|
temb,
|
791
764
|
None,
|
@@ -793,8 +766,8 @@ class CogVideoXEncoder3D(nn.Module):
|
|
793
766
|
)
|
794
767
|
|
795
768
|
# 2. Mid
|
796
|
-
hidden_states, new_conv_cache["mid_block"] =
|
797
|
-
|
769
|
+
hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
|
770
|
+
self.mid_block,
|
798
771
|
hidden_states,
|
799
772
|
temb,
|
800
773
|
None,
|
@@ -940,16 +913,9 @@ class CogVideoXDecoder3D(nn.Module):
|
|
940
913
|
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
941
914
|
|
942
915
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
943
|
-
|
944
|
-
def create_custom_forward(module):
|
945
|
-
def custom_forward(*inputs):
|
946
|
-
return module(*inputs)
|
947
|
-
|
948
|
-
return custom_forward
|
949
|
-
|
950
916
|
# 1. Mid
|
951
|
-
hidden_states, new_conv_cache["mid_block"] =
|
952
|
-
|
917
|
+
hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
|
918
|
+
self.mid_block,
|
953
919
|
hidden_states,
|
954
920
|
temb,
|
955
921
|
sample,
|
@@ -959,8 +925,8 @@ class CogVideoXDecoder3D(nn.Module):
|
|
959
925
|
# 2. Up
|
960
926
|
for i, up_block in enumerate(self.up_blocks):
|
961
927
|
conv_cache_key = f"up_block_{i}"
|
962
|
-
hidden_states, new_conv_cache[conv_cache_key] =
|
963
|
-
|
928
|
+
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
929
|
+
up_block,
|
964
930
|
hidden_states,
|
965
931
|
temb,
|
966
932
|
sample,
|
@@ -1122,10 +1088,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
1122
1088
|
self.tile_overlap_factor_height = 1 / 6
|
1123
1089
|
self.tile_overlap_factor_width = 1 / 5
|
1124
1090
|
|
1125
|
-
def _set_gradient_checkpointing(self, module, value=False):
|
1126
|
-
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
|
1127
|
-
module.gradient_checkpointing = value
|
1128
|
-
|
1129
1091
|
def enable_tiling(
|
1130
1092
|
self,
|
1131
1093
|
tile_sample_min_height: Optional[int] = None,
|
@@ -12,7 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from typing import
|
15
|
+
from typing import Optional, Tuple, Union
|
16
16
|
|
17
17
|
import numpy as np
|
18
18
|
import torch
|
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
|
21
21
|
import torch.utils.checkpoint
|
22
22
|
|
23
23
|
from ...configuration_utils import ConfigMixin, register_to_config
|
24
|
-
from ...utils import
|
24
|
+
from ...utils import logging
|
25
25
|
from ...utils.accelerate_utils import apply_forward_hook
|
26
26
|
from ..activations import get_activation
|
27
27
|
from ..attention_processor import Attention
|
@@ -36,11 +36,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|
36
36
|
def prepare_causal_attention_mask(
|
37
37
|
num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
|
38
38
|
) -> torch.Tensor:
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
39
|
+
indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)
|
40
|
+
indices_blocks = indices.repeat_interleave(height_width)
|
41
|
+
x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
|
42
|
+
mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype)
|
43
|
+
|
44
44
|
if batch_size is not None:
|
45
45
|
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
|
46
46
|
return mask
|
@@ -252,21 +252,7 @@ class HunyuanVideoMidBlock3D(nn.Module):
|
|
252
252
|
|
253
253
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
254
254
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
255
|
-
|
256
|
-
def create_custom_forward(module, return_dict=None):
|
257
|
-
def custom_forward(*inputs):
|
258
|
-
if return_dict is not None:
|
259
|
-
return module(*inputs, return_dict=return_dict)
|
260
|
-
else:
|
261
|
-
return module(*inputs)
|
262
|
-
|
263
|
-
return custom_forward
|
264
|
-
|
265
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
266
|
-
|
267
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
268
|
-
create_custom_forward(self.resnets[0]), hidden_states, **ckpt_kwargs
|
269
|
-
)
|
255
|
+
hidden_states = self._gradient_checkpointing_func(self.resnets[0], hidden_states)
|
270
256
|
|
271
257
|
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
272
258
|
if attn is not None:
|
@@ -278,9 +264,7 @@ class HunyuanVideoMidBlock3D(nn.Module):
|
|
278
264
|
hidden_states = attn(hidden_states, attention_mask=attention_mask)
|
279
265
|
hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
|
280
266
|
|
281
|
-
hidden_states =
|
282
|
-
create_custom_forward(resnet), hidden_states, **ckpt_kwargs
|
283
|
-
)
|
267
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
|
284
268
|
|
285
269
|
else:
|
286
270
|
hidden_states = self.resnets[0](hidden_states)
|
@@ -350,22 +334,8 @@ class HunyuanVideoDownBlock3D(nn.Module):
|
|
350
334
|
|
351
335
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
352
336
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
353
|
-
|
354
|
-
def create_custom_forward(module, return_dict=None):
|
355
|
-
def custom_forward(*inputs):
|
356
|
-
if return_dict is not None:
|
357
|
-
return module(*inputs, return_dict=return_dict)
|
358
|
-
else:
|
359
|
-
return module(*inputs)
|
360
|
-
|
361
|
-
return custom_forward
|
362
|
-
|
363
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
364
|
-
|
365
337
|
for resnet in self.resnets:
|
366
|
-
hidden_states =
|
367
|
-
create_custom_forward(resnet), hidden_states, **ckpt_kwargs
|
368
|
-
)
|
338
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
|
369
339
|
else:
|
370
340
|
for resnet in self.resnets:
|
371
341
|
hidden_states = resnet(hidden_states)
|
@@ -426,22 +396,8 @@ class HunyuanVideoUpBlock3D(nn.Module):
|
|
426
396
|
|
427
397
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
428
398
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
429
|
-
|
430
|
-
def create_custom_forward(module, return_dict=None):
|
431
|
-
def custom_forward(*inputs):
|
432
|
-
if return_dict is not None:
|
433
|
-
return module(*inputs, return_dict=return_dict)
|
434
|
-
else:
|
435
|
-
return module(*inputs)
|
436
|
-
|
437
|
-
return custom_forward
|
438
|
-
|
439
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
440
|
-
|
441
399
|
for resnet in self.resnets:
|
442
|
-
hidden_states =
|
443
|
-
create_custom_forward(resnet), hidden_states, **ckpt_kwargs
|
444
|
-
)
|
400
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
|
445
401
|
|
446
402
|
else:
|
447
403
|
for resnet in self.resnets:
|
@@ -545,26 +501,10 @@ class HunyuanVideoEncoder3D(nn.Module):
|
|
545
501
|
hidden_states = self.conv_in(hidden_states)
|
546
502
|
|
547
503
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
548
|
-
|
549
|
-
def create_custom_forward(module, return_dict=None):
|
550
|
-
def custom_forward(*inputs):
|
551
|
-
if return_dict is not None:
|
552
|
-
return module(*inputs, return_dict=return_dict)
|
553
|
-
else:
|
554
|
-
return module(*inputs)
|
555
|
-
|
556
|
-
return custom_forward
|
557
|
-
|
558
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
559
|
-
|
560
504
|
for down_block in self.down_blocks:
|
561
|
-
hidden_states =
|
562
|
-
create_custom_forward(down_block), hidden_states, **ckpt_kwargs
|
563
|
-
)
|
505
|
+
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
|
564
506
|
|
565
|
-
hidden_states =
|
566
|
-
create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
|
567
|
-
)
|
507
|
+
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
|
568
508
|
else:
|
569
509
|
for down_block in self.down_blocks:
|
570
510
|
hidden_states = down_block(hidden_states)
|
@@ -667,26 +607,10 @@ class HunyuanVideoDecoder3D(nn.Module):
|
|
667
607
|
hidden_states = self.conv_in(hidden_states)
|
668
608
|
|
669
609
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
670
|
-
|
671
|
-
def create_custom_forward(module, return_dict=None):
|
672
|
-
def custom_forward(*inputs):
|
673
|
-
if return_dict is not None:
|
674
|
-
return module(*inputs, return_dict=return_dict)
|
675
|
-
else:
|
676
|
-
return module(*inputs)
|
677
|
-
|
678
|
-
return custom_forward
|
679
|
-
|
680
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
681
|
-
|
682
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
683
|
-
create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
|
684
|
-
)
|
610
|
+
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
|
685
611
|
|
686
612
|
for up_block in self.up_blocks:
|
687
|
-
hidden_states =
|
688
|
-
create_custom_forward(up_block), hidden_states, **ckpt_kwargs
|
689
|
-
)
|
613
|
+
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
|
690
614
|
else:
|
691
615
|
hidden_states = self.mid_block(hidden_states)
|
692
616
|
|
@@ -786,7 +710,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
|
|
786
710
|
self.use_tiling = False
|
787
711
|
|
788
712
|
# When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
|
789
|
-
# at a fixed frame batch size (based on `self.
|
713
|
+
# at a fixed frame batch size (based on `self.tile_sample_min_num_frames`), the memory requirement can be lowered.
|
790
714
|
self.use_framewise_encoding = True
|
791
715
|
self.use_framewise_decoding = True
|
792
716
|
|
@@ -800,10 +724,6 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
|
|
800
724
|
self.tile_sample_stride_width = 192
|
801
725
|
self.tile_sample_stride_num_frames = 12
|
802
726
|
|
803
|
-
def _set_gradient_checkpointing(self, module, value=False):
|
804
|
-
if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)):
|
805
|
-
module.gradient_checkpointing = value
|
806
|
-
|
807
727
|
def enable_tiling(
|
808
728
|
self,
|
809
729
|
tile_sample_min_height: Optional[int] = None,
|
@@ -868,7 +788,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
|
|
868
788
|
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
869
789
|
batch_size, num_channels, num_frames, height, width = x.shape
|
870
790
|
|
871
|
-
if self.
|
791
|
+
if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames:
|
872
792
|
return self._temporal_tiled_encode(x)
|
873
793
|
|
874
794
|
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|