diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -196,6 +196,55 @@ class LTXVideoResnetBlock3d(nn.Module):
         return hidden_states
 
 
+class LTXVideoDownsampler3d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        stride: Union[int, Tuple[int, int, int]] = 1,
+        is_causal: bool = True,
+        padding_mode: str = "zeros",
+    ) -> None:
+        super().__init__()
+
+        self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
+        self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels
+
+        out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2])
+
+        self.conv = LTXVideoCausalConv3d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            stride=1,
+            is_causal=is_causal,
+            padding_mode=padding_mode,
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)
+
+        residual = (
+            hidden_states.unflatten(4, (-1, self.stride[2]))
+            .unflatten(3, (-1, self.stride[1]))
+            .unflatten(2, (-1, self.stride[0]))
+        )
+        residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
+        residual = residual.unflatten(1, (-1, self.group_size))
+        residual = residual.mean(dim=2)
+
+        hidden_states = self.conv(hidden_states)
+        hidden_states = (
+            hidden_states.unflatten(4, (-1, self.stride[2]))
+            .unflatten(3, (-1, self.stride[1]))
+            .unflatten(2, (-1, self.stride[0]))
+        )
+        hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class LTXVideoUpsampler3d(nn.Module):
     def __init__(
         self,
@@ -204,6 +253,7 @@ class LTXVideoUpsampler3d(nn.Module):
         is_causal: bool = True,
         residual: bool = False,
         upscale_factor: int = 1,
+        padding_mode: str = "zeros",
     ) -> None:
         super().__init__()
 
@@ -219,6 +269,7 @@ class LTXVideoUpsampler3d(nn.Module):
             kernel_size=3,
             stride=1,
             is_causal=is_causal,
+            padding_mode=padding_mode,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -338,16 +389,122 @@ class LTXVideoDownBlock3D(nn.Module):
 
         for i, resnet in enumerate(self.resnets):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
+            else:
+                hidden_states = resnet(hidden_states, temb, generator)
 
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
 
-                    return create_forward
+        if self.conv_out is not None:
+            hidden_states = self.conv_out(hidden_states, temb, generator)
+
+        return hidden_states
+
+
+class LTXVideo095DownBlock3D(nn.Module):
+    r"""
+    Down block used in the LTXVideo model.
+
+    Args:
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`, *optional*):
+            Number of output channels. If None, defaults to `in_channels`.
+        num_layers (`int`, defaults to `1`):
+            Number of resnet layers.
+        dropout (`float`, defaults to `0.0`):
+            Dropout rate.
+        resnet_eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        resnet_act_fn (`str`, defaults to `"swish"`):
+            Activation function to use.
+        spatio_temporal_scale (`bool`, defaults to `True`):
+            Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
+            Whether or not to downsample across temporal dimension.
+        is_causal (`bool`, defaults to `True`):
+            Whether this layer behaves causally (future frames depend only on past frames) or not.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        num_layers: int = 1,
+        dropout: float = 0.0,
+        resnet_eps: float = 1e-6,
+        resnet_act_fn: str = "swish",
+        spatio_temporal_scale: bool = True,
+        is_causal: bool = True,
+        downsample_type: str = "conv",
+    ):
+        super().__init__()
+
+        out_channels = out_channels or in_channels
+
+        resnets = []
+        for _ in range(num_layers):
+            resnets.append(
+                LTXVideoResnetBlock3d(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    dropout=dropout,
+                    eps=resnet_eps,
+                    non_linearity=resnet_act_fn,
+                    is_causal=is_causal,
+                )
+            )
+        self.resnets = nn.ModuleList(resnets)
+
+        self.downsamplers = None
+        if spatio_temporal_scale:
+            self.downsamplers = nn.ModuleList()
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, generator
+            if downsample_type == "conv":
+                self.downsamplers.append(
+                    LTXVideoCausalConv3d(
+                        in_channels=in_channels,
+                        out_channels=in_channels,
+                        kernel_size=3,
+                        stride=(2, 2, 2),
+                        is_causal=is_causal,
+                    )
+                )
+            elif downsample_type == "spatial":
+                self.downsamplers.append(
+                    LTXVideoDownsampler3d(
+                        in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal
+                    )
+                )
+            elif downsample_type == "temporal":
+                self.downsamplers.append(
+                    LTXVideoDownsampler3d(
+                        in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal
+                    )
+                )
+            elif downsample_type == "spatiotemporal":
+                self.downsamplers.append(
+                    LTXVideoDownsampler3d(
+                        in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal
+                    )
                 )
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+    ) -> torch.Tensor:
+        r"""Forward method of the `LTXDownBlock3D` class."""
+
+        for i, resnet in enumerate(self.resnets):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
             else:
                 hidden_states = resnet(hidden_states, temb, generator)
 
@@ -355,9 +512,6 @@ class LTXVideoDownBlock3D(nn.Module):
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
 
-        if self.conv_out is not None:
-            hidden_states = self.conv_out(hidden_states, temb, generator)
-
         return hidden_states
 
 
@@ -438,16 +592,7 @@ class LTXVideoMidBlock3d(nn.Module):
 
         for i, resnet in enumerate(self.resnets):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-
-                    return create_forward
-
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, generator
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
             else:
                 hidden_states = resnet(hidden_states, temb, generator)
 
@@ -573,16 +718,7 @@ class LTXVideoUpBlock3d(nn.Module):
 
         for i, resnet in enumerate(self.resnets):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-
-                    return create_forward
-
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, generator
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
             else:
                 hidden_states = resnet(hidden_states, temb, generator)
 
@@ -620,8 +756,15 @@ class LTXVideoEncoder3d(nn.Module):
         in_channels: int = 3,
         out_channels: int = 128,
         block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+        down_block_types: Tuple[str, ...] = (
+            "LTXVideoDownBlock3D",
+            "LTXVideoDownBlock3D",
+            "LTXVideoDownBlock3D",
+            "LTXVideoDownBlock3D",
+        ),
         spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
         layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
+        downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
         patch_size: int = 4,
         patch_size_t: int = 1,
         resnet_norm_eps: float = 1e-6,
@@ -644,20 +787,37 @@ class LTXVideoEncoder3d(nn.Module):
         )
 
         # down blocks
-        num_block_out_channels = len(block_out_channels)
+        is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D"
+        num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0)
         self.down_blocks = nn.ModuleList([])
         for i in range(num_block_out_channels):
             input_channel = output_channel
-            output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
-
-            down_block = LTXVideoDownBlock3D(
-                in_channels=input_channel,
-                out_channels=output_channel,
-                num_layers=layers_per_block[i],
-                resnet_eps=resnet_norm_eps,
-                spatio_temporal_scale=spatio_temporal_scaling[i],
-                is_causal=is_causal,
-            )
+            if not is_ltx_095:
+                output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
+            else:
+                output_channel = block_out_channels[i + 1]
+
+            if down_block_types[i] == "LTXVideoDownBlock3D":
+                down_block = LTXVideoDownBlock3D(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    num_layers=layers_per_block[i],
+                    resnet_eps=resnet_norm_eps,
+                    spatio_temporal_scale=spatio_temporal_scaling[i],
+                    is_causal=is_causal,
+                )
+            elif down_block_types[i] == "LTXVideo095DownBlock3D":
+                down_block = LTXVideo095DownBlock3D(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    num_layers=layers_per_block[i],
+                    resnet_eps=resnet_norm_eps,
+                    spatio_temporal_scale=spatio_temporal_scaling[i],
+                    is_causal=is_causal,
+                    downsample_type=downsample_type[i],
+                )
+            else:
+                raise ValueError(f"Unknown down block type: {down_block_types[i]}")
 
             self.down_blocks.append(down_block)
 
@@ -697,17 +857,10 @@ class LTXVideoEncoder3d(nn.Module):
         hidden_states = self.conv_in(hidden_states)
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def create_forward(*inputs):
-                    return module(*inputs)
-
-                return create_forward
-
             for down_block in self.down_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), hidden_states)
+                hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
 
-            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
+            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
         else:
             for down_block in self.down_blocks:
                 hidden_states = down_block(hidden_states)
@@ -828,7 +981,9 @@ class LTXVideoDecoder3d(nn.Module):
         # timestep embedding
         self.time_embedder = None
         self.scale_shift_table = None
+        self.timestep_scale_multiplier = None
         if timestep_conditioning:
+            self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
             self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
             self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
 
@@ -837,20 +992,14 @@ class LTXVideoDecoder3d(nn.Module):
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.conv_in(hidden_states)
 
-        if torch.is_grad_enabled() and self.gradient_checkpointing:
+        if self.timestep_scale_multiplier is not None:
+            temb = temb * self.timestep_scale_multiplier
 
-            def create_custom_forward(module):
-                def create_forward(*inputs):
-                    return module(*inputs)
-
-                return create_forward
-
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb
-            )
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
 
             for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb)
+                hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb)
         else:
             hidden_states = self.mid_block(hidden_states, temb)
 
@@ -934,12 +1083,19 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         out_channels: int = 3,
         latent_channels: int = 128,
         block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+        down_block_types: Tuple[str, ...] = (
+            "LTXVideoDownBlock3D",
+            "LTXVideoDownBlock3D",
+            "LTXVideoDownBlock3D",
+            "LTXVideoDownBlock3D",
+        ),
         decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
         layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
         decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
         spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
         decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
         decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
+        downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
         upsample_residual: Tuple[bool, ...] = (False, False, False, False),
         upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
         timestep_conditioning: bool = False,
@@ -949,6 +1105,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         scaling_factor: float = 1.0,
         encoder_causal: bool = True,
         decoder_causal: bool = False,
+        spatial_compression_ratio: int = None,
+        temporal_compression_ratio: int = None,
     ) -> None:
         super().__init__()
 
@@ -956,8 +1114,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             in_channels=in_channels,
             out_channels=latent_channels,
             block_out_channels=block_out_channels,
+            down_block_types=down_block_types,
             spatio_temporal_scaling=spatio_temporal_scaling,
             layers_per_block=layers_per_block,
+            downsample_type=downsample_type,
             patch_size=patch_size,
             patch_size_t=patch_size_t,
             resnet_norm_eps=resnet_norm_eps,
@@ -984,8 +1144,16 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.register_buffer("latents_mean", latents_mean, persistent=True)
         self.register_buffer("latents_std", latents_std, persistent=True)
 
-        self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
-        self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
+        self.spatial_compression_ratio = (
+            patch_size * 2 ** sum(spatio_temporal_scaling)
+            if spatial_compression_ratio is None
+            else spatial_compression_ratio
+        )
+        self.temporal_compression_ratio = (
+            patch_size_t * 2 ** sum(spatio_temporal_scaling)
+            if temporal_compression_ratio is None
+            else temporal_compression_ratio
+        )
 
         # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
         # to perform decoding of a single video latent at a time.
@@ -1010,21 +1178,21 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 512
         self.tile_sample_min_width = 512
+        self.tile_sample_min_num_frames = 16
 
         # The minimal distance between two spatial tiles
         self.tile_sample_stride_height = 448
         self.tile_sample_stride_width = 448
-
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
-            module.gradient_checkpointing = value
+        self.tile_sample_stride_num_frames = 8
 
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
         tile_sample_min_width: Optional[int] = None,
+        tile_sample_min_num_frames: Optional[int] = None,
         tile_sample_stride_height: Optional[float] = None,
         tile_sample_stride_width: Optional[float] = None,
+        tile_sample_stride_num_frames: Optional[float] = None,
     ) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -1046,8 +1214,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.use_tiling = True
         self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
 
     def disable_tiling(self) -> None:
         r"""
@@ -1073,18 +1243,13 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = x.shape
 
+        if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames:
+            return self._temporal_tiled_encode(x)
+
         if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
             return self.tiled_encode(x)
 
-        if self.use_framewise_encoding:
-            # TODO(aryan): requires investigation
-            raise NotImplementedError(
-                "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                "quality issues caused by splitting inference across frame dimension. If you believe this "
-                "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-            )
-        else:
-            enc = self.encoder(x)
+        enc = self.encoder(x)
 
         return enc
 
@@ -1121,19 +1286,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         batch_size, num_channels, num_frames, height, width = z.shape
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+
+        if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
+            return self._temporal_tiled_decode(z, temb, return_dict=return_dict)
 
         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z, temb, return_dict=return_dict)
 
-        if self.use_framewise_decoding:
-            # TODO(aryan): requires investigation
-            raise NotImplementedError(
-                "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                "quality issues caused by splitting inference across frame dimension. If you believe this "
-                "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-            )
-        else:
-            dec = self.decoder(z, temb)
+        dec = self.decoder(z, temb)
 
         if not return_dict:
             return (dec,)
@@ -1189,6 +1350,14 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         )
         return b
 
+    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+                x / blend_extent
+            )
+        return b
+
     def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         r"""Encode a batch of images using a tiled encoder.
 
@@ -1217,17 +1386,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for i in range(0, height, self.tile_sample_stride_height):
             row = []
             for j in range(0, width, self.tile_sample_stride_width):
-                if self.use_framewise_encoding:
-                    # TODO(aryan): requires investigation
-                    raise NotImplementedError(
-                        "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                        "quality issues caused by splitting inference across frame dimension. If you believe this "
-                        "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-                    )
-                else:
-                    time = self.encoder(
-                        x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
-                    )
+                time = self.encoder(
+                    x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                )
 
                 row.append(time)
             rows.append(row)
@@ -1283,17 +1444,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for i in range(0, height, tile_latent_stride_height):
             row = []
             for j in range(0, width, tile_latent_stride_width):
-                if self.use_framewise_decoding:
-                    # TODO(aryan): requires investigation
-                    raise NotImplementedError(
-                        "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                        "quality issues caused by splitting inference across frame dimension. If you believe this "
-                        "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-                    )
-                else:
-                    time = self.decoder(
-                        z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
-                    )
+                time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb)
 
                 row.append(time)
             rows.append(row)
@@ -1318,6 +1469,74 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
         return DecoderOutput(sample=dec)
 
+    def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+        batch_size, num_channels, num_frames, height, width = x.shape
+        latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
+
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+        blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
+
+        row = []
+        for i in range(0, num_frames, self.tile_sample_stride_num_frames):
+            tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
+            if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
+                tile = self.tiled_encode(tile)
+            else:
+                tile = self.encoder(tile)
+            if i > 0:
+                tile = tile[:, :, 1:, :, :]
+            row.append(tile)
+
+        result_row = []
+        for i, tile in enumerate(row):
+            if i > 0:
+                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+                result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
+            else:
+                result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
+
+        enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
+        return enc
+
+    def _temporal_tiled_decode(
+        self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        batch_size, num_channels, num_frames, height, width = z.shape
+        num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+        blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
+
+        row = []
+        for i in range(0, num_frames, tile_latent_stride_num_frames):
+            tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
+            if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
+                decoded = self.tiled_decode(tile, temb, return_dict=True).sample
+            else:
+                decoded = self.decoder(tile, temb)
+            if i > 0:
+                decoded = decoded[:, :, :-1, :, :]
+            row.append(decoded)
+
+        result_row = []
+        for i, tile in enumerate(row):
+            if i > 0:
+                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+                tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
+                result_row.append(tile)
+            else:
+                result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
+
+        dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
@@ -1334,5 +1553,5 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             z = posterior.mode()
         dec = self.decode(z, temb)
         if not return_dict:
-            return (dec,)
+            return (dec.sample,)
        return dec
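
For orientation, and not part of the diff itself: the hunks above add frame-wise temporal tiling to `AutoencoderKLLTXVideo`, exposed through the new `tile_sample_min_num_frames` / `tile_sample_stride_num_frames` arguments of `enable_tiling()` and blended across tile boundaries by the new `blend_t()` helper. Below is a minimal, hedged sketch of how those knobs might be exercised; the `Lightricks/LTX-Video` repo id, the `vae` subfolder, and toggling the `use_framewise_*` attributes directly are assumptions for illustration, not something this diff specifies.

```python
# Hedged sketch of the 0.33.0 temporal-tiling additions shown in the hunks above.
# Assumptions: the "Lightricks/LTX-Video" checkpoint layout, and that setting the
# use_framewise_* flags directly is an acceptable way to opt into frame-wise tiling.
import torch
from diffusers import AutoencoderKLLTXVideo

vae = AutoencoderKLLTXVideo.from_pretrained(
    "Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32
)
vae.enable_tiling(
    tile_sample_min_height=512,
    tile_sample_min_width=512,
    tile_sample_min_num_frames=16,    # new in 0.33.0
    tile_sample_stride_height=448,
    tile_sample_stride_width=448,
    tile_sample_stride_num_frames=8,  # new in 0.33.0
)
# Assumption: these flags gate _temporal_tiled_encode/_temporal_tiled_decode above.
vae.use_framewise_encoding = True
vae.use_framewise_decoding = True

# Dummy clip: batch x channels x frames x height x width. With more frames than
# tile_sample_min_num_frames, encoding and decoding run tile-by-tile along the
# frame axis and the overlapping frames are blended by blend_t().
video = torch.randn(1, 3, 49, 256, 256)
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
    frames = vae.decode(latents).sample
print(frames.shape)
```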