diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/models/autoencoders/autoencoder_kl_wan.py  +626 -62

The detailed diff below covers the Wan VAE: it gains residual down/up blocks, patchified input/output, and tiling/slicing support used by the Wan 2.2 checkpoints.

@@ -34,6 +34,103 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 CACHE_T = 2
 
 
+class AvgDown3D(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        factor_t,
+        factor_s=1,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.factor_t = factor_t
+        self.factor_s = factor_s
+        self.factor = self.factor_t * self.factor_s * self.factor_s
+
+        assert in_channels * self.factor % out_channels == 0
+        self.group_size = in_channels * self.factor // out_channels
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
+        pad = (0, 0, 0, 0, pad_t, 0)
+        x = F.pad(x, pad)
+        B, C, T, H, W = x.shape
+        x = x.view(
+            B,
+            C,
+            T // self.factor_t,
+            self.factor_t,
+            H // self.factor_s,
+            self.factor_s,
+            W // self.factor_s,
+            self.factor_s,
+        )
+        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
+        x = x.view(
+            B,
+            C * self.factor,
+            T // self.factor_t,
+            H // self.factor_s,
+            W // self.factor_s,
+        )
+        x = x.view(
+            B,
+            self.out_channels,
+            self.group_size,
+            T // self.factor_t,
+            H // self.factor_s,
+            W // self.factor_s,
+        )
+        x = x.mean(dim=2)
+        return x
+
+
+class DupUp3D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        factor_t,
+        factor_s=1,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.factor_t = factor_t
+        self.factor_s = factor_s
+        self.factor = self.factor_t * self.factor_s * self.factor_s
+
+        assert out_channels * self.factor % in_channels == 0
+        self.repeats = out_channels * self.factor // in_channels
+
+    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
+        x = x.repeat_interleave(self.repeats, dim=1)
+        x = x.view(
+            x.size(0),
+            self.out_channels,
+            self.factor_t,
+            self.factor_s,
+            self.factor_s,
+            x.size(2),
+            x.size(3),
+            x.size(4),
+        )
+        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
+        x = x.view(
+            x.size(0),
+            self.out_channels,
+            x.size(2) * self.factor_t,
+            x.size(4) * self.factor_s,
+            x.size(6) * self.factor_s,
+        )
+        if first_chunk:
+            x = x[:, :, self.factor_t - 1 :, :, :]
+        return x
+
+
 class WanCausalConv3d(nn.Conv3d):
     r"""
     A custom 3D causal convolution layer with feature caching support.
@@ -134,19 +231,25 @@ class WanResample(nn.Module):
             - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
     """
 
-    def __init__(self, dim: int, mode: str) -> None:
+    def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
         super().__init__()
         self.dim = dim
         self.mode = mode
 
+        # default to dim // 2
+        if upsample_out_dim is None:
+            upsample_out_dim = dim // 2
+
         # layers
         if mode == "upsample2d":
             self.resample = nn.Sequential(
-                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+                nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
             )
         elif mode == "upsample3d":
             self.resample = nn.Sequential(
-                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+                nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
             )
             self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
 
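
The only functional change to `WanResample` is the new `upsample_out_dim` argument. The 2.1-style decoder halves channels after nearest-exact upsampling (`dim -> dim // 2`); the residual up blocks added below pass `upsample_out_dim=out_dim` so the channel count survives and the `DupUp3D` shortcut can be added elementwise. A small sketch of the two channel layouts, using plain `torch.nn` with illustrative sizes:

```python
import torch
import torch.nn as nn

dim = 128  # illustrative channel width
legacy = nn.Sequential(
    nn.Upsample(scale_factor=2.0, mode="nearest-exact"),
    nn.Conv2d(dim, dim // 2, 3, padding=1),   # 2.1 path: channels halve
)
residual = nn.Sequential(
    nn.Upsample(scale_factor=2.0, mode="nearest-exact"),
    nn.Conv2d(dim, dim, 3, padding=1),        # 2.2 path: upsample_out_dim=dim keeps channels
)

x = torch.randn(1, dim, 16, 16)
print(legacy(x).shape)    # torch.Size([1, 64, 32, 32])
print(residual(x).shape)  # torch.Size([1, 128, 32, 32])
```
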
@@ -363,6 +466,42 @@ class WanMidBlock(nn.Module):
         return x
 
 
+class WanResidualDownBlock(nn.Module):
+    def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False):
+        super().__init__()
+
+        # Shortcut path with downsample
+        self.avg_shortcut = AvgDown3D(
+            in_dim,
+            out_dim,
+            factor_t=2 if temperal_downsample else 1,
+            factor_s=2 if down_flag else 1,
+        )
+
+        # Main path with residual blocks and downsample
+        resnets = []
+        for _ in range(num_res_blocks):
+            resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
+            in_dim = out_dim
+        self.resnets = nn.ModuleList(resnets)
+
+        # Add the final downsample block
+        if down_flag:
+            mode = "downsample3d" if temperal_downsample else "downsample2d"
+            self.downsampler = WanResample(out_dim, mode=mode)
+        else:
+            self.downsampler = None
+
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        x_copy = x.clone()
+        for resnet in self.resnets:
+            x = resnet(x, feat_cache, feat_idx)
+        if self.downsampler is not None:
+            x = self.downsampler(x, feat_cache, feat_idx)
+
+        return x + self.avg_shortcut(x_copy)
+
+
 class WanEncoder3d(nn.Module):
     r"""
     A 3D encoder module.
@@ -380,6 +519,7 @@ class WanEncoder3d(nn.Module):
 
     def __init__(
         self,
+        in_channels: int = 3,
         dim=128,
         z_dim=4,
         dim_mult=[1, 2, 4, 4],
@@ -388,6 +528,7 @@ class WanEncoder3d(nn.Module):
         temperal_downsample=[True, True, False],
         dropout=0.0,
         non_linearity: str = "silu",
+        is_residual: bool = False,  # wan 2.2 vae use a residual downblock
     ):
         super().__init__()
         self.dim = dim
@@ -403,23 +544,35 @@ class WanEncoder3d(nn.Module):
         scale = 1.0
 
         # init block
-        self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
+        self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
 
         # downsample blocks
         self.down_blocks = nn.ModuleList([])
         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
             # residual (+attention) blocks
-            for _ in range(num_res_blocks):
-                self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
-                if scale in attn_scales:
-                    self.down_blocks.append(WanAttentionBlock(out_dim))
-                in_dim = out_dim
-
-            # downsample block
-            if i != len(dim_mult) - 1:
-                mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
-                self.down_blocks.append(WanResample(out_dim, mode=mode))
-                scale /= 2.0
+            if is_residual:
+                self.down_blocks.append(
+                    WanResidualDownBlock(
+                        in_dim,
+                        out_dim,
+                        dropout,
+                        num_res_blocks,
+                        temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
+                        down_flag=i != len(dim_mult) - 1,
+                    )
+                )
+            else:
+                for _ in range(num_res_blocks):
+                    self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
+                    if scale in attn_scales:
+                        self.down_blocks.append(WanAttentionBlock(out_dim))
+                    in_dim = out_dim
+
+                # downsample block
+                if i != len(dim_mult) - 1:
+                    mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
+                    self.down_blocks.append(WanResample(out_dim, mode=mode))
+                    scale /= 2.0
 
         # middle blocks
         self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
@@ -470,6 +623,94 @@ class WanEncoder3d(nn.Module):
         return x
 
 
+class WanResidualUpBlock(nn.Module):
+    """
+    A block that handles upsampling for the WanVAE decoder.
+
+    Args:
+        in_dim (int): Input dimension
+        out_dim (int): Output dimension
+        num_res_blocks (int): Number of residual blocks
+        dropout (float): Dropout rate
+        temperal_upsample (bool): Whether to upsample on temporal dimension
+        up_flag (bool): Whether to upsample or not
+        non_linearity (str): Type of non-linearity to use
+    """
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        num_res_blocks: int,
+        dropout: float = 0.0,
+        temperal_upsample: bool = False,
+        up_flag: bool = False,
+        non_linearity: str = "silu",
+    ):
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+
+        if up_flag:
+            self.avg_shortcut = DupUp3D(
+                in_dim,
+                out_dim,
+                factor_t=2 if temperal_upsample else 1,
+                factor_s=2,
+            )
+        else:
+            self.avg_shortcut = None
+
+        # create residual blocks
+        resnets = []
+        current_dim = in_dim
+        for _ in range(num_res_blocks + 1):
+            resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
+            current_dim = out_dim
+
+        self.resnets = nn.ModuleList(resnets)
+
+        # Add upsampling layer if needed
+        if up_flag:
+            upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
+            self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
+        else:
+            self.upsampler = None
+
+        self.gradient_checkpointing = False
+
+    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
+        """
+        Forward pass through the upsampling block.
+
+        Args:
+            x (torch.Tensor): Input tensor
+            feat_cache (list, optional): Feature cache for causal convolutions
+            feat_idx (list, optional): Feature index for cache management
+
+        Returns:
+            torch.Tensor: Output tensor
+        """
+        x_copy = x.clone()
+
+        for resnet in self.resnets:
+            if feat_cache is not None:
+                x = resnet(x, feat_cache, feat_idx)
+            else:
+                x = resnet(x)
+
+        if self.upsampler is not None:
+            if feat_cache is not None:
+                x = self.upsampler(x, feat_cache, feat_idx)
+            else:
+                x = self.upsampler(x)
+
+        if self.avg_shortcut is not None:
+            x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
+
+        return x
+
+
 class WanUpBlock(nn.Module):
     """
     A block that handles upsampling for the WanVAE decoder.
@@ -513,7 +754,7 @@ class WanUpBlock(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
         """
         Forward pass through the upsampling block.
 
@@ -564,6 +805,8 @@ class WanDecoder3d(nn.Module):
         temperal_upsample=[False, True, True],
         dropout=0.0,
         non_linearity: str = "silu",
+        out_channels: int = 3,
+        is_residual: bool = False,
     ):
         super().__init__()
         self.dim = dim
@@ -577,7 +820,6 @@ class WanDecoder3d(nn.Module):
 
         # dimensions
         dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
-        scale = 1.0 / 2 ** (len(dim_mult) - 2)
 
         # init block
         self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
@@ -589,36 +831,47 @@ class WanDecoder3d(nn.Module):
         self.up_blocks = nn.ModuleList([])
         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
             # residual (+attention) blocks
-            if i > 0:
+            if i > 0 and not is_residual:
+                # wan vae 2.1
                 in_dim = in_dim // 2
 
-            # determine if we need upsampling
+            # determine if we need upsampling
+            up_flag = i != len(dim_mult) - 1
+            # determine upsampling mode, if not upsampling, set to None
             upsample_mode = None
-            if i != len(dim_mult) - 1:
-                upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
-
+            if up_flag and temperal_upsample[i]:
+                upsample_mode = "upsample3d"
+            elif up_flag:
+                upsample_mode = "upsample2d"
             # Create and add the upsampling block
-            up_block = WanUpBlock(
-                in_dim=in_dim,
-                out_dim=out_dim,
-                num_res_blocks=num_res_blocks,
-                dropout=dropout,
-                upsample_mode=upsample_mode,
-                non_linearity=non_linearity,
-            )
+            if is_residual:
+                up_block = WanResidualUpBlock(
+                    in_dim=in_dim,
+                    out_dim=out_dim,
+                    num_res_blocks=num_res_blocks,
+                    dropout=dropout,
+                    temperal_upsample=temperal_upsample[i] if up_flag else False,
+                    up_flag=up_flag,
+                    non_linearity=non_linearity,
+                )
+            else:
+                up_block = WanUpBlock(
+                    in_dim=in_dim,
+                    out_dim=out_dim,
+                    num_res_blocks=num_res_blocks,
+                    dropout=dropout,
+                    upsample_mode=upsample_mode,
+                    non_linearity=non_linearity,
+                )
             self.up_blocks.append(up_block)
 
-            # Update scale for next iteration
-            if upsample_mode is not None:
-                scale *= 2.0
-
         # output blocks
         self.norm_out = WanRMS_norm(out_dim, images=False)
-        self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
+        self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
         ## conv1
         if feat_cache is not None:
             idx = feat_idx[0]
@@ -637,7 +890,7 @@ class WanDecoder3d(nn.Module):
 
         ## upsamples
         for up_block in self.up_blocks:
-            x = up_block(x, feat_cache, feat_idx)
+            x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
 
         ## head
         x = self.norm_out(x)
@@ -656,6 +909,49 @@ class WanDecoder3d(nn.Module):
         return x
 
 
+def patchify(x, patch_size):
+    if patch_size == 1:
+        return x
+
+    if x.dim() != 5:
+        raise ValueError(f"Invalid input shape: {x.shape}")
+    # x shape: [batch_size, channels, frames, height, width]
+    batch_size, channels, frames, height, width = x.shape
+
+    # Ensure height and width are divisible by patch_size
+    if height % patch_size != 0 or width % patch_size != 0:
+        raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
+
+    # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
+    x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
+
+    # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
+    x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
+    x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
+
+    return x
+
+
+def unpatchify(x, patch_size):
+    if patch_size == 1:
+        return x
+
+    if x.dim() != 5:
+        raise ValueError(f"Invalid input shape: {x.shape}")
+    # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
+    batch_size, c_patches, frames, height, width = x.shape
+    channels = c_patches // (patch_size * patch_size)
+
+    # Reshape to [b, c, patch_size, patch_size, f, h, w]
+    x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
+
+    # Rearrange to [b, c, f, h * patch_size, w * patch_size]
+    x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
+    x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
+
+    return x
+
+
 class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
@@ -671,6 +967,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     def __init__(
         self,
         base_dim: int = 96,
+        decoder_base_dim: Optional[int] = None,
         z_dim: int = 16,
         dim_mult: Tuple[int] = [1, 2, 4, 4],
         num_res_blocks: int = 2,
@@ -713,6 +1010,12 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             2.8251,
             1.9160,
         ],
+        is_residual: bool = False,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        patch_size: Optional[int] = None,
+        scale_factor_temporal: Optional[int] = 4,
+        scale_factor_spatial: Optional[int] = 8,
     ) -> None:
         super().__init__()
 
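
The six new `__init__` arguments are registered to the config, so Wan 2.1 checkpoints keep loading with the old defaults while Wan 2.2 VAE configs opt into the residual blocks and patchified I/O. A hedged loading sketch; the checkpoint id below is illustrative, not guaranteed, and any Wan 2.2 VAE repo in diffusers format would exercise the same code paths:

```python
import torch
from diffusers import AutoencoderKLWan

# Hypothetical repo id; substitute a real Wan 2.2 VAE checkpoint.
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.2-TI2V-5B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
print(vae.config.is_residual, vae.config.patch_size, vae.config.z_dim)
```
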
@@ -720,37 +1023,135 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
720
1023
|
self.temperal_downsample = temperal_downsample
|
721
1024
|
self.temperal_upsample = temperal_downsample[::-1]
|
722
1025
|
|
1026
|
+
if decoder_base_dim is None:
|
1027
|
+
decoder_base_dim = base_dim
|
1028
|
+
|
723
1029
|
self.encoder = WanEncoder3d(
|
724
|
-
|
1030
|
+
in_channels=in_channels,
|
1031
|
+
dim=base_dim,
|
1032
|
+
z_dim=z_dim * 2,
|
1033
|
+
dim_mult=dim_mult,
|
1034
|
+
num_res_blocks=num_res_blocks,
|
1035
|
+
attn_scales=attn_scales,
|
1036
|
+
temperal_downsample=temperal_downsample,
|
1037
|
+
dropout=dropout,
|
1038
|
+
is_residual=is_residual,
|
725
1039
|
)
|
726
1040
|
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
727
1041
|
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
|
728
1042
|
|
729
1043
|
self.decoder = WanDecoder3d(
|
730
|
-
|
1044
|
+
dim=decoder_base_dim,
|
1045
|
+
z_dim=z_dim,
|
1046
|
+
dim_mult=dim_mult,
|
1047
|
+
num_res_blocks=num_res_blocks,
|
1048
|
+
attn_scales=attn_scales,
|
1049
|
+
temperal_upsample=self.temperal_upsample,
|
1050
|
+
dropout=dropout,
|
1051
|
+
out_channels=out_channels,
|
1052
|
+
is_residual=is_residual,
|
731
1053
|
)
|
732
1054
|
|
1055
|
+
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
|
1056
|
+
|
1057
|
+
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
1058
|
+
# to perform decoding of a single video latent at a time.
|
1059
|
+
self.use_slicing = False
|
1060
|
+
|
1061
|
+
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
|
1062
|
+
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
|
1063
|
+
# intermediate tiles together, the memory requirement can be lowered.
|
1064
|
+
self.use_tiling = False
|
1065
|
+
|
1066
|
+
# The minimal tile height and width for spatial tiling to be used
|
1067
|
+
self.tile_sample_min_height = 256
|
1068
|
+
self.tile_sample_min_width = 256
|
1069
|
+
|
1070
|
+
# The minimal distance between two spatial tiles
|
1071
|
+
self.tile_sample_stride_height = 192
|
1072
|
+
self.tile_sample_stride_width = 192
|
1073
|
+
|
1074
|
+
# Precompute and cache conv counts for encoder and decoder for clear_cache speedup
|
1075
|
+
self._cached_conv_counts = {
|
1076
|
+
"decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
|
1077
|
+
if self.decoder is not None
|
1078
|
+
else 0,
|
1079
|
+
"encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
|
1080
|
+
if self.encoder is not None
|
1081
|
+
else 0,
|
1082
|
+
}
|
1083
|
+
|
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[float] = None,
+        tile_sample_stride_width: Optional[float] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
+        processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
+                artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
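`enable_tiling` is the public toggle for the tiled path. A usage sketch; the checkpoint id is illustrative and not taken from this diff:

```python
import torch
from diffusers import AutoencoderKLWan

# Illustrative checkpoint; any Wan VAE exported in the diffusers layout works.
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)

# Defaults: 256x256 tiles placed every 192 px. A larger stride means less
# overlap (less compute, more seam risk); a smaller stride the opposite.
vae.enable_tiling()
```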
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
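Slicing is orthogonal to tiling: it splits only the batch dimension, decoding one video latent at a time. A sketch continuing the example above (shapes are illustrative):

```python
# A batch of two latent videos: (batch, z_dim=16, latent_frames, latent_h, latent_w).
latents = torch.randn(2, 16, 13, 60, 104)

vae.enable_slicing()                 # decode one video latent at a time
frames = vae.decode(latents).sample  # peak memory ~ a single-video decode
vae.disable_slicing()
```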
     def clear_cache(self):
-        def _count_conv3d(model):
-            count = 0
-            for m in model.modules():
-                if isinstance(m, WanCausalConv3d):
-                    count += 1
-            return count
-
-        self._conv_num = _count_conv3d(self.decoder)
+        # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
+        self._conv_num = self._cached_conv_counts["decoder"]
         self._conv_idx = [0]
         self._feat_map = [None] * self._conv_num
         # cache encode
-        self._enc_conv_num = _count_conv3d(self.encoder)
+        self._enc_conv_num = self._cached_conv_counts["encoder"]
         self._enc_conv_idx = [0]
         self._enc_feat_map = [None] * self._enc_conv_num

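The `clear_cache` change swaps a per-call module walk for a dict lookup; since `clear_cache` runs at the start and end of every encode/decode, the walk was repeated work. A generic sketch of the pattern (names are placeholders, not diffusers internals):

```python
import torch.nn as nn

def count_instances(model: nn.Module, cls: type) -> int:
    # What the removed _count_conv3d helper computed, as a generator expression.
    return sum(isinstance(m, cls) for m in model.modules())

# Counting once at construction and caching the result turns every later
# clear_cache() into an O(1) lookup instead of an O(#modules) traversal.
```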
-    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+    def _encode(self, x: torch.Tensor):
+        _, _, num_frame, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
         self.clear_cache()
-        ## cache
-        t = x.shape[2]
-        iter_ = 1 + (t - 1) // 4
+        if self.config.patch_size is not None:
+            x = patchify(x, patch_size=self.config.patch_size)
+        iter_ = 1 + (num_frame - 1) // 4
         for i in range(iter_):
             self._enc_conv_idx = [0]
             if i == 0:
@@ -764,8 +1165,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 out = torch.cat([out, out_], 2)

         enc = self.quant_conv(out)
-        mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
-        enc = torch.cat([mu, logvar], dim=1)
         self.clear_cache()
         return enc

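The `1 + (num_frame - 1) // 4` chunking in `_encode` above reflects the causal 4x temporal compression: the first frame is encoded alone, the rest in groups of four. A worked example:

```python
# Temporal chunk boundaries for a 17-frame clip under the scheme above.
num_frame = 17
iter_ = 1 + (num_frame - 1) // 4          # 1 + 16 // 4 == 5 encoder passes
chunks = [(0, 1)] + [(1 + 4 * (k - 1), 1 + 4 * k) for k in range(1, iter_)]
assert chunks == [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)]
```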
@@ -785,26 +1184,42 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self._encode(x)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
         posterior = DiagonalGaussianDistribution(h)
+
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)

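The public `encode` entry point is unchanged in shape: video in, posterior out. A sketch with illustrative shapes, reusing `vae` from above:

```python
# (batch, channels, frames, height, width), values scaled to [-1, 1].
video = torch.randn(1, 3, 17, 480, 832)

posterior = vae.encode(video).latent_dist  # DiagonalGaussianDistribution
latents = posterior.sample()               # or posterior.mode() for the mean
```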
-    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        self.clear_cache()
+    def _decode(self, z: torch.Tensor, return_dict: bool = True):
+        _, _, num_frame, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

-        iter_ = z.shape[2]
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+            return self.tiled_decode(z, return_dict=return_dict)
+
+        self.clear_cache()
         x = self.post_quant_conv(z)
-        for i in range(iter_):
+        for i in range(num_frame):
             self._conv_idx = [0]
             if i == 0:
-                out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                out = self.decoder(
+                    x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True
+                )
             else:
                 out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                 out = torch.cat([out, out_], 2)

+        if self.config.patch_size is not None:
+            out = unpatchify(out, patch_size=self.config.patch_size)
+
         out = torch.clamp(out, min=-1.0, max=1.0)
+
         self.clear_cache()
         if not return_dict:
             return (out,)
@@ -826,12 +1241,161 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
-        decoded = self._decode(z).sample
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
         if not return_dict:
             return (decoded,)
-
         return DecoderOutput(sample=decoded)

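A round-trip sketch tying `encode` and `decode` together, continuing the tensors above; the `[-1, 1]` range follows from the `torch.clamp` in `_decode`:

```python
frames = vae.decode(latents).sample  # latents from the encode sketch above
assert frames.min() >= -1.0 and frames.max() <= 1.0  # clamped in _decode
```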
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
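`blend_v` and `blend_h` implement a linear crossfade over the overlap region. A numeric sketch of the weights for `blend_extent = 4`:

```python
blend_extent = 4
weights_a = [1 - y / blend_extent for y in range(blend_extent)]  # [1.0, 0.75, 0.5, 0.25]
weights_b = [y / blend_extent for y in range(blend_extent)]      # [0.0, 0.25, 0.5, 0.75]
# Row y of the seam becomes a[-4 + y] * weights_a[y] + b[y] * weights_b[y],
# fading from pure tile `a` into pure tile `b` across the overlap.
```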
+    def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+        r"""Encode a batch of images using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        _, _, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                self.clear_cache()
+                time = []
+                frame_range = 1 + (num_frames - 1) // 4
+                for k in range(frame_range):
+                    self._enc_conv_idx = [0]
+                    if k == 0:
+                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    else:
+                        tile = x[
+                            :,
+                            :,
+                            1 + 4 * (k - 1) : 1 + 4 * k,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                    tile = self.quant_conv(tile)
+                    time.append(tile)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+        self.clear_cache()
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the tile above and the tile to the left into the current tile,
+                # then add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return enc
+
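To see the cost profile of `tiled_encode`, count the tile grid the default stride produces for a typical Wan resolution (a sketch; the resolution is illustrative):

```python
height, width = 480, 832                   # illustrative Wan 2.1 sample size
stride = 192                               # default tile_sample_stride_*
tile_rows = len(range(0, height, stride))  # 3
tile_cols = len(range(0, width, stride))   # 5
# 15 encoder passes per temporal chunk instead of one full-frame pass:
# more compute, much lower peak memory.
```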
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of images using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                self.clear_cache()
+                time = []
+                for k in range(num_frames):
+                    self._conv_idx = [0]
+                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                    tile = self.post_quant_conv(tile)
+                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    time.append(decoded)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+        self.clear_cache()
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the tile above and the tile to the left into the current tile,
+                # then add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
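One subtlety worth noting: `tiled_encode` blends in latent space while `tiled_decode` blends in sample space, but with the defaults the seam width matches in pixel terms (sketch assumes a compression ratio of 8):

```python
ratio = 8                                   # assumed spatial_compression_ratio
blend_latent = 256 // ratio - 192 // ratio  # tiled_encode: 8 latent px
blend_sample = 256 - 192                    # tiled_decode: 64 sample px
assert blend_latent * ratio == blend_sample
```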
     def forward(
         self,
         sample: torch.Tensor,
|