diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,978 @@
|
|
1
|
+
# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import html
|
16
|
+
import math
|
17
|
+
import re
|
18
|
+
from copy import deepcopy
|
19
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
20
|
+
|
21
|
+
import ftfy
|
22
|
+
import torch
|
23
|
+
from transformers import AutoTokenizer, UMT5EncoderModel
|
24
|
+
|
25
|
+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
26
|
+
from ...loaders import SkyReelsV2LoraLoaderMixin
|
27
|
+
from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel
|
28
|
+
from ...schedulers import UniPCMultistepScheduler
|
29
|
+
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
|
30
|
+
from ...utils.torch_utils import randn_tensor
|
31
|
+
from ...video_processor import VideoProcessor
|
32
|
+
from ..pipeline_utils import DiffusionPipeline
|
33
|
+
from .pipeline_output import SkyReelsV2PipelineOutput
|
34
|
+
|
35
|
+
|
36
|
+
if is_torch_xla_available():
|
37
|
+
import torch_xla.core.xla_model as xm
|
38
|
+
|
39
|
+
XLA_AVAILABLE = True
|
40
|
+
else:
|
41
|
+
XLA_AVAILABLE = False
|
42
|
+
|
43
|
+
# Module-level logger obtained from diffusers' logging utilities.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
44
|
+
|
45
|
+
if is_ftfy_available():
|
46
|
+
import ftfy
|
47
|
+
|
48
|
+
|
49
|
+
EXAMPLE_DOC_STRING = """\
|
50
|
+
Examples:
|
51
|
+
```py
|
52
|
+
>>> import torch
|
53
|
+
>>> from diffusers import (
|
54
|
+
... SkyReelsV2DiffusionForcingPipeline,
|
55
|
+
... UniPCMultistepScheduler,
|
56
|
+
... AutoencoderKLWan,
|
57
|
+
... )
|
58
|
+
>>> from diffusers.utils import export_to_video
|
59
|
+
|
60
|
+
>>> # Load the pipeline
|
61
|
+
>>> # Available models:
|
62
|
+
>>> # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
|
63
|
+
>>> # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
|
64
|
+
>>> # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
|
65
|
+
>>> vae = AutoencoderKLWan.from_pretrained(
|
66
|
+
... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
|
67
|
+
... subfolder="vae",
|
68
|
+
... torch_dtype=torch.float32,
|
69
|
+
... )
|
70
|
+
>>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
|
71
|
+
... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
|
72
|
+
... vae=vae,
|
73
|
+
... torch_dtype=torch.bfloat16,
|
74
|
+
... )
|
75
|
+
>>> flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
|
76
|
+
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
|
77
|
+
>>> pipe = pipe.to("cuda")
|
78
|
+
|
79
|
+
>>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
|
80
|
+
|
81
|
+
>>> output = pipe(
|
82
|
+
... prompt=prompt,
|
83
|
+
... num_inference_steps=30,
|
84
|
+
... height=544,
|
85
|
+
... width=960,
|
86
|
+
... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V
|
87
|
+
... num_frames=97,
|
88
|
+
... ar_step=5, # Controls asynchronous inference (0 for synchronous mode)
|
89
|
+
... causal_block_size=5, # Number of frames processed together in a causal block
|
90
|
+
... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos
|
91
|
+
... addnoise_condition=20, # Improves consistency in long video generation
|
92
|
+
... ).frames[0]
|
93
|
+
>>> export_to_video(output, "video.mp4", fps=24, quality=8)
|
94
|
+
```
|
95
|
+
"""
|
96
|
+
|
97
|
+
|
98
|
+
def basic_clean(text):
    """Repair broken Unicode with ftfy, decode HTML entities (two levels deep), and trim surrounding whitespace."""
    repaired = ftfy.fix_text(text)
    # Decode twice so entities that were escaped twice (e.g. "&amp;lt;") end up fully decoded.
    for _ in range(2):
        repaired = html.unescape(repaired)
    return repaired.strip()
|
102
|
+
|
103
|
+
|
104
|
+
def whitespace_clean(text):
    """Collapse every run of whitespace into one space and strip both ends of the string."""
    # str.split() with no argument splits on runs of Unicode whitespace and drops empty edge pieces,
    # which matches substituting r"\s+" with " " and then stripping.
    return " ".join(text.split())
|
108
|
+
|
109
|
+
|
110
|
+
def prompt_clean(text):
    """Normalize a raw prompt: repair text/HTML artifacts first, then collapse whitespace."""
    return whitespace_clean(basic_clean(text))
|
113
|
+
|
114
|
+
|
115
|
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Extract latents from a VAE encoder output.

    Args:
        encoder_output: Object returned by a VAE `encode` call. Despite the annotation, it is inspected for a
            `latent_dist` attribute (something exposing `sample(generator)` and `mode()`) or a `latents` attribute.
        generator: Optional RNG forwarded to `latent_dist.sample` when `sample_mode == "sample"`.
        sample_mode: `"sample"` draws from `latent_dist`; `"argmax"` returns `latent_dist.mode()`.

    Returns:
        The latents.

    Raises:
        AttributeError: If no usable attribute is found. This also happens when only `latent_dist` exists and
            `sample_mode` is neither `"sample"` nor `"argmax"`.
    """
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
|
127
|
+
|
128
|
+
|
129
|
+
class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
|
130
|
+
"""
|
131
|
+
Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing.
|
132
|
+
|
133
|
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
134
|
+
implemented for all pipelines (downloading, saving, running on a specific device, etc.).
|
135
|
+
|
136
|
+
Args:
|
137
|
+
tokenizer ([`AutoTokenizer`]):
|
138
|
+
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
|
139
|
+
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
140
|
+
text_encoder ([`UMT5EncoderModel`]):
|
141
|
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
142
|
+
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
143
|
+
transformer ([`SkyReelsV2Transformer3DModel`]):
|
144
|
+
Conditional Transformer to denoise the encoded image latents.
|
145
|
+
scheduler ([`UniPCMultistepScheduler`]):
|
146
|
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
147
|
+
vae ([`AutoencoderKLWan`]):
|
148
|
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
149
|
+
"""
|
150
|
+
|
151
|
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
152
|
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
153
|
+
|
154
|
+
def __init__(
    self,
    tokenizer: AutoTokenizer,
    text_encoder: UMT5EncoderModel,
    transformer: SkyReelsV2Transformer3DModel,
    vae: AutoencoderKLWan,
    scheduler: UniPCMultistepScheduler,
):
    """Register the pipeline components and derive the VAE compression factors and video processor."""
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
        scheduler=scheduler,
    )

    registered_vae = getattr(self, "vae", None)
    if registered_vae:
        # `temperal_downsample` (attribute name as spelled on the VAE) is summed for the temporal factor and
        # counted for the spatial factor; presumably one boolean per downsampling stage -- confirm in
        # AutoencoderKLWan.
        downsample_flags = registered_vae.temperal_downsample
        self.vae_scale_factor_temporal = 2 ** sum(downsample_flags)
        self.vae_scale_factor_spatial = 2 ** len(downsample_flags)
    else:
        # Fallback factors used when no VAE is registered.
        self.vae_scale_factor_temporal = 4
        self.vae_scale_factor_spatial = 8

    self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
175
|
+
|
176
|
+
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """
    Encode prompt(s) with the UMT5 text encoder.

    Each prompt is cleaned with `prompt_clean` and tokenized (padded/truncated) to `max_sequence_length` tokens.
    The tokens are then encoded, and hidden states at padding positions are replaced with zeros.

    Args:
        prompt: A single prompt or a list of prompts. Must not be `None` (its length is taken).
        num_videos_per_prompt: Number of copies of each prompt's embedding in the output.
        max_sequence_length: Token length used for tokenization and for the output sequence dimension.
        device: Target device; defaults to `self._execution_device`.
        dtype: Target dtype; defaults to `self.text_encoder.dtype`.

    Returns:
        `torch.Tensor` of shape `(len(prompt) * num_videos_per_prompt, max_sequence_length, hidden_dim)`, where the
        copies of each prompt are adjacent along the first axis.
    """
    device = device or self._execution_device
    dtype = dtype or self.text_encoder.dtype

    prompt = [prompt] if isinstance(prompt, str) else prompt
    prompt = [prompt_clean(u) for u in prompt]
    batch_size = len(prompt)

    text_inputs = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
    # Number of real (non-padding) tokens in each prompt.
    seq_lens = mask.gt(0).sum(dim=1).long()

    prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
    # Keep only the real-token hidden states, then right-pad each prompt back to `max_sequence_length` with zeros,
    # so padding positions are exactly zero rather than whatever the encoder produced there.
    prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
    )

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds
|
217
|
+
|
218
|
+
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]] = None,
    do_classifier_free_guidance: bool = True,
    num_videos_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded. May be `None` when `prompt_embeds` is provided.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. When neither this nor
            `negative_prompt_embeds` is provided, the empty string is used. Only used when
            `do_classifier_free_guidance` is `True`.
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            Whether to use classifier free guidance or not.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Number of videos that should be generated per prompt.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        max_sequence_length (`int`, *optional*, defaults to 226):
            Token length forwarded to `_get_t5_prompt_embeds`.
        device: (`torch.device`, *optional*):
            torch device to place the resulting embeddings on; defaults to `self._execution_device`.
        dtype: (`torch.dtype`, *optional*):
            torch dtype

    Returns:
        `(prompt_embeds, negative_prompt_embeds)`. `negative_prompt_embeds` stays `None` when guidance is disabled
        and none were passed in.

    Raises:
        TypeError: If `negative_prompt` and `prompt` end up with different types.
        ValueError: If `negative_prompt` has a different batch size than `prompt`.
    """
    device = device or self._execution_device

    prompt = [prompt] if isinstance(prompt, str) else prompt
    if prompt is not None:
        batch_size = len(prompt)
    else:
        # NOTE(review): with precomputed `prompt_embeds`, the batch size is their leading dim, which may already
        # include a num_videos_per_prompt expansion -- confirm against callers.
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        prompt_embeds = self._get_t5_prompt_embeds(
            prompt=prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    if do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        # A single negative prompt string is broadcast to every prompt in the batch.
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        negative_prompt_embeds = self._get_t5_prompt_embeds(
            prompt=negative_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    return prompt_embeds, negative_prompt_embeds
|
299
|
+
|
300
|
+
def check_inputs(
    self,
    prompt,
    negative_prompt,
    height,
    width,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    overlap_history=None,
    num_frames=None,
    base_num_frames=None,
):
    """
    Validate user-facing generation arguments before any work is done.

    Args:
        prompt: Text prompt(s) (`str` or `list`), or `None` when `prompt_embeds` is given.
        negative_prompt: Negative prompt(s) (`str` or `list`), or `None`.
        height, width: Output size in pixels; both must be divisible by 16.
        prompt_embeds, negative_prompt_embeds: Precomputed embeddings; mutually exclusive with the text variants.
        callback_on_step_end_tensor_inputs: Names requested by a step-end callback; each must be listed in
            `self._callback_tensor_inputs`.
        overlap_history: Frame overlap between long-video windows.
        num_frames, base_num_frames: Requested frame count and per-window frame count. The long-video check
            (`overlap_history` required when `num_frames > base_num_frames`) runs only when both are given.

    Raises:
        ValueError: On any invalid combination described above.
    """
    if height % 16 != 0 or width % 16 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    elif negative_prompt is not None and (
        not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
    ):
        raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

    # Both frame counts default to None; only enforce the long-video requirement when both are known, so that
    # calling this helper with its defaults does not fail with a TypeError on `None > None`.
    if (
        num_frames is not None
        and base_num_frames is not None
        and num_frames > base_num_frames
        and overlap_history is None
    ):
        raise ValueError(
            "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. "
            "Please specify a value for `overlap_history`. Recommended values are 17 or 37."
        )
|
349
|
+
|
350
|
+
def prepare_latents(
    self,
    batch_size: int,
    num_channels_latents: int = 16,
    height: int = 480,
    width: int = 832,
    num_frames: int = 97,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    base_latent_num_frames: Optional[int] = None,
    video_latents: Optional[torch.Tensor] = None,
    causal_block_size: Optional[int] = None,
    overlap_history_latent_frames: Optional[int] = None,
    long_video_iter: Optional[int] = None,
) -> tuple[torch.Tensor, int, Optional[torch.Tensor], int]:
    """
    Prepare the initial latents for one generation window.

    Three cases determine the window's latent frame count:
      * `video_latents` given (long-video iterations after the first): the last `overlap_history_latent_frames`
        frames of `video_latents` become the prefix. If needed, trailing frames are dropped so the prefix length
        is a multiple of `causal_block_size`. The window is the remaining frames plus the overlap, capped at
        `base_latent_num_frames`.
      * only `base_latent_num_frames` given (first long-video iteration): the window is `base_latent_num_frames`.
      * neither given (short video): the window covers all `num_frames`.

    Args:
        batch_size: Number of latent samples to draw.
        num_channels_latents: Latent channel count.
        height, width: Output size in pixels; divided by `self.vae_scale_factor_spatial`.
        num_frames: Total output frame count in pixel space.
        dtype, device: Placement of the returned latents.
        generator: RNG or list of RNGs (one per batch element) used for sampling.
        latents: Optional user-provided latents. If given, no noise is sampled.
        base_latent_num_frames: Per-window latent frame budget for long videos.
        video_latents: Latents generated so far (long videos, later iterations).
        causal_block_size: Block size used to align the prefix length.
        overlap_history_latent_frames: Number of latent frames carried over between windows.
        long_video_iter: Index of the current long-video iteration.

    Returns:
        A 4-tuple `(latents, num_latent_frames, prefix_video_latents, prefix_video_latents_frames)`:
        - the latents of shape `(batch_size, num_channels_latents, num_latent_frames, H, W)`;
        - the window's latent frame count;
        - the prefix latents, or `None`;
        - the prefix length (0 when there is no prefix).
        The same 4-tuple is returned when `latents` is supplied. In that case `latents` is moved to
        `device`/`dtype` and its own frame count (`latents.shape[2]`) is reported, so downstream scheduling
        matches the tensor actually used.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
    latent_height = height // self.vae_scale_factor_spatial
    latent_width = width // self.vae_scale_factor_spatial

    prefix_video_latents = None
    prefix_video_latents_frames = 0

    if video_latents is not None:  # long video generation at the iterations other than the first one
        prefix_video_latents = video_latents[:, :, -overlap_history_latent_frames:]

        if prefix_video_latents.shape[2] % causal_block_size != 0:
            truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size
            logger.warning(
                f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. "
                "This truncation ensures compatibility with the causal block size, which is required for proper processing. "
                "However, it may slightly affect the continuity of the generated video at the truncation boundary."
            )
            prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents]
        prefix_video_latents_frames = prefix_video_latents.shape[2]

        finished_frame_num = (
            long_video_iter * (base_latent_num_frames - overlap_history_latent_frames)
            + overlap_history_latent_frames
        )
        left_frame_num = num_latent_frames - finished_frame_num
        num_latent_frames = min(left_frame_num + overlap_history_latent_frames, base_latent_num_frames)
    elif base_latent_num_frames is not None:  # long video generation at the first iteration
        num_latent_frames = base_latent_num_frames
    # Otherwise (short video generation) the full-length count computed above is used unchanged.

    if latents is not None:
        # Return the same 4-tuple as the sampling path so callers can always unpack four values.
        latents = latents.to(device=device, dtype=dtype)
        return latents, latents.shape[2], prefix_video_latents, prefix_video_latents_frames

    shape = (
        batch_size,
        num_channels_latents,
        num_latent_frames,
        latent_height,
        latent_width,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    return latents, num_latent_frames, prefix_video_latents, prefix_video_latents_frames
|
417
|
+
|
418
|
+
def generate_timestep_matrix(
    self,
    num_latent_frames: int,
    step_template: torch.Tensor,
    base_num_latent_frames: int,
    ar_step: int = 5,
    num_pre_ready: int = 0,
    causal_block_size: int = 1,
    shrink_interval_with_mask: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
    """
    Build the diffusion-forcing denoising schedule across latent frames.

    Frames are grouped into causal blocks of `causal_block_size` frames, and each returned row is one denoising
    iteration. Block 0 advances one step per iteration. Every later block advances one step per iteration once
    the block before it has reached the last entry of `step_template`. Until then it sits `ar_step` steps behind
    the block before it, clamped at "not started".

    **Synchronous mode** (`ar_step=0`): absent pre-ready blocks, all blocks share the same step each iteration,
    so every frame follows the same trajectory through `step_template`.

    **Asynchronous mode** (`ar_step>0`): later blocks lag behind earlier ones, producing a staggered "denoising
    wave" in which earlier blocks are always further along.

    Args:
        num_latent_frames (int): Total number of latent frames to schedule.
        step_template (torch.Tensor): Base timestep schedule (e.g. the scheduler's timesteps).
        base_num_latent_frames (int): Maximum number of latent frames handled in one forward pass; bounds the
            width of each returned interval.
        ar_step (int, optional): Lag, in steps, between consecutive blocks. 0 = synchronous. Defaults to 5.
        num_pre_ready (int, optional): Number of leading frames that are already clean (e.g. a prefix video).
            Their blocks are marked finished from the start. Defaults to 0.
        causal_block_size (int, optional): Number of frames per causal block. Defaults to 1.
        shrink_interval_with_mask (bool, optional): If True, the starting interval ends right after the last
            block updated in the first iteration, instead of at `base_num_latent_frames`. Defaults to False.

    Returns:
        tuple containing (with `num_blocks = num_latent_frames // causal_block_size`):
            - step_matrix (torch.Tensor): Timestep per frame per iteration, shape
              `[num_rows, num_blocks * causal_block_size]`.
            - step_index (torch.Tensor): Index into the padded template `[999, *step_template, 0]`, same shape.
            - step_update_mask (torch.Tensor): Boolean mask of frames updated at each iteration, same shape.
            - valid_interval (list[tuple]): Per-iteration `(start, end)` frame range to feed to the model.

    Raises:
        ValueError: If more blocks are needed than fit in one pass and `ar_step` is smaller than
            `len(step_template) / (base_num_latent_frames // causal_block_size)`.
    """
    step_matrix, step_index = [], []  # timestep values and template indices, one row per iteration
    update_mask, valid_interval = [], []  # per-iteration update masks and processing intervals

    # Index `num_iterations` is the terminal "completely denoised" slot of the padded template built below.
    num_iterations = len(step_template) + 1

    # Work at block granularity; each block holds `causal_block_size` frames scheduled together.
    num_blocks = num_latent_frames // causal_block_size
    base_num_blocks = base_num_latent_frames // causal_block_size

    # When the video does not fit in one pass, the stagger must be wide enough for the sliding window to work.
    if base_num_blocks < num_blocks:
        min_ar_step = len(step_template) / base_num_blocks
        if ar_step < min_ar_step:
            raise ValueError(f"`ar_step` should be at least {math.ceil(min_ar_step)} in your setting")

    # Pad the template: index 0 is a dummy (counters start at 1), the last index is the final timestep 0.
    step_template = torch.cat(
        [
            torch.tensor([999], dtype=torch.int64, device=step_template.device),
            step_template.long(),
            torch.tensor([0], dtype=torch.int64, device=step_template.device),
        ]
    )

    # Per-block progress: 0 = not started, num_iterations = fully denoised.
    pre_row = torch.zeros(num_blocks, dtype=torch.long)

    # Blocks covered by pre-ready frames start in the finished state.
    if num_pre_ready > 0:
        pre_row[: num_pre_ready // causal_block_size] = num_iterations

    # Emit rows until every block has reached (at least) the last real timestep.
    while not torch.all(pre_row >= (num_iterations - 1)):
        new_row = torch.zeros(num_blocks, dtype=torch.long)

        for i in range(num_blocks):
            if i == 0 or pre_row[i - 1] >= (num_iterations - 1):
                # First block, or the previous block has reached its last step: advance freely.
                new_row[i] = pre_row[i] + 1
            else:
                # Lag `ar_step` steps behind the previous block's new position (the staggered pattern).
                new_row[i] = new_row[i - 1] - ar_step

        new_row = new_row.clamp(0, num_iterations)

        # A block is updated this iteration if its index moved and it has not hit the terminal slot.
        update_mask.append((new_row != pre_row) & (new_row != num_iterations))

        step_index.append(new_row)
        step_matrix.append(step_template[new_row])
        pre_row = new_row

    # Right edge (exclusive, in blocks) of the model's window; the window slides for long videos.
    terminal_flag = base_num_blocks

    if shrink_interval_with_mask and update_mask:
        # Keep the first mask under its own name: `update_mask` must remain the list of per-iteration masks
        # for the interval loop and `torch.stack` below.
        first_update_mask = update_mask[0]
        update_mask_idx = torch.arange(num_blocks, dtype=torch.int64)[first_update_mask]
        if update_mask_idx.numel() > 0:
            terminal_flag = update_mask_idx[-1].item() + 1

    for curr_mask in update_mask:
        # Slide the window right by one block when the next block starts updating.
        if terminal_flag < num_blocks and curr_mask[terminal_flag]:
            terminal_flag += 1
        # Interval [start, end) never wider than the model capacity.
        valid_interval.append((max(terminal_flag - base_num_blocks, 0), terminal_flag))

    step_update_mask = torch.stack(update_mask, dim=0)
    step_index = torch.stack(step_index, dim=0)
    step_matrix = torch.stack(step_matrix, dim=0)

    # Replicate each block's schedule to every frame in that block.
    if causal_block_size > 1:
        step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
        step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
        step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
        # Convert block-level intervals to frame-level.
        valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval]

    return step_matrix, step_index, step_update_mask, valid_interval
|
573
|
+
|
574
|
+
@property
def guidance_scale(self):
    """Guidance scale stored on the pipeline (read from `self._guidance_scale`)."""
    return self._guidance_scale

@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active, i.e. the stored guidance scale is strictly above 1."""
    return 1.0 < self._guidance_scale

@property
def num_timesteps(self):
    """Number of timesteps recorded on the pipeline (read from `self._num_timesteps`)."""
    return self._num_timesteps

@property
def current_timestep(self):
    """Timestep currently recorded on the pipeline (read from `self._current_timestep`)."""
    return self._current_timestep

@property
def interrupt(self):
    """Interrupt flag recorded on the pipeline (read from `self._interrupt`)."""
    return self._interrupt

@property
def attention_kwargs(self):
    """Attention keyword arguments recorded on the pipeline (read from `self._attention_kwargs`)."""
    return self._attention_kwargs
|
597
|
+
|
598
|
+
@torch.no_grad()
|
599
|
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
600
|
+
def __call__(
|
601
|
+
self,
|
602
|
+
prompt: Union[str, List[str]],
|
603
|
+
negative_prompt: Union[str, List[str]] = None,
|
604
|
+
height: int = 544,
|
605
|
+
width: int = 960,
|
606
|
+
num_frames: int = 97,
|
607
|
+
num_inference_steps: int = 50,
|
608
|
+
guidance_scale: float = 6.0,
|
609
|
+
num_videos_per_prompt: Optional[int] = 1,
|
610
|
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
611
|
+
latents: Optional[torch.Tensor] = None,
|
612
|
+
prompt_embeds: Optional[torch.Tensor] = None,
|
613
|
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
614
|
+
output_type: Optional[str] = "np",
|
615
|
+
return_dict: bool = True,
|
616
|
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
617
|
+
callback_on_step_end: Optional[
|
618
|
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
619
|
+
] = None,
|
620
|
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
621
|
+
max_sequence_length: int = 512,
|
622
|
+
overlap_history: Optional[int] = None,
|
623
|
+
addnoise_condition: float = 0,
|
624
|
+
base_num_frames: int = 97,
|
625
|
+
ar_step: int = 0,
|
626
|
+
causal_block_size: Optional[int] = None,
|
627
|
+
fps: int = 24,
|
628
|
+
):
|
629
|
+
r"""
|
630
|
+
The call function to the pipeline for generation.
|
631
|
+
|
632
|
+
Args:
|
633
|
+
prompt (`str` or `List[str]`, *optional*):
|
634
|
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
635
|
+
instead.
|
636
|
+
negative_prompt (`str` or `List[str]`, *optional*):
|
637
|
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
638
|
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
639
|
+
less than `1`).
|
640
|
+
height (`int`, defaults to `544`):
|
641
|
+
The height of the generated video.
|
642
|
+
width (`int`, defaults to `960`):
|
643
|
+
The width of the generated video.
|
644
|
+
num_frames (`int`, defaults to `97`):
|
645
|
+
The number of frames in the generated video.
|
646
|
+
num_inference_steps (`int`, defaults to `50`):
|
647
|
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
648
|
+
expense of slower inference.
|
649
|
+
guidance_scale (`float`, defaults to `6.0`):
|
650
|
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
651
|
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
652
|
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
653
|
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
654
|
+
usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**)
|
655
|
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
656
|
+
The number of images to generate per prompt.
|
657
|
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
658
|
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
659
|
+
generation deterministic.
|
660
|
+
latents (`torch.Tensor`, *optional*):
|
661
|
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
662
|
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
663
|
+
tensor is generated by sampling using the supplied random `generator`.
|
664
|
+
prompt_embeds (`torch.Tensor`, *optional*):
|
665
|
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
666
|
+
provided, text embeddings are generated from the `prompt` input argument.
|
667
|
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
668
|
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
669
|
+
provided, text embeddings are generated from the `negative_prompt` input argument.
|
670
|
+
output_type (`str`, *optional*, defaults to `"np"`):
|
671
|
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
672
|
+
return_dict (`bool`, *optional*, defaults to `True`):
|
673
|
+
Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple.
|
674
|
+
attention_kwargs (`dict`, *optional*):
|
675
|
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
676
|
+
`self.processor` in
|
677
|
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
678
|
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
679
|
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
680
|
+
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
681
|
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
682
|
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
683
|
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
684
|
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
685
|
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
686
|
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
687
|
+
max_sequence_length (`int`, *optional*, defaults to `512`):
|
688
|
+
The maximum sequence length of the prompt.
|
689
|
+
overlap_history (`int`, *optional*, defaults to `None`):
|
690
|
+
Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes
|
691
|
+
short video generation mode, and no overlap is applied. 17 and 37 are recommended to set.
|
692
|
+
addnoise_condition (`float`, *optional*, defaults to `0`):
|
693
|
+
This is used to help smooth the long video generation by adding some noise to the clean condition. Too
|
694
|
+
large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger
|
695
|
+
ones, but it is recommended to not exceed 50.
|
696
|
+
base_num_frames (`int`, *optional*, defaults to `97`):
|
697
|
+
97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**)
|
698
|
+
ar_step (`int`, *optional*, defaults to `0`):
|
699
|
+
Controls asynchronous inference (0 for synchronous mode) You can set `ar_step=5` to enable asynchronous
|
700
|
+
inference. When asynchronous inference, `causal_block_size=5` is recommended while it is not supposed
|
701
|
+
to be set for synchronous generation. Asynchronous inference will take more steps to diffuse the whole
|
702
|
+
sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous
|
703
|
+
inference may improve the instruction following and visual consistent performance.
|
704
|
+
causal_block_size (`int`, *optional*, defaults to `None`):
|
705
|
+
The number of frames in each block/chunk. Recommended when using asynchronous inference (when ar_step >
|
706
|
+
0)
|
707
|
+
fps (`int`, *optional*, defaults to `24`):
|
708
|
+
Frame rate of the generated video
|
709
|
+
|
710
|
+
Examples:
|
711
|
+
|
712
|
+
Returns:
|
713
|
+
[`~SkyReelsV2PipelineOutput`] or `tuple`:
|
714
|
+
If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned
|
715
|
+
where the first element is a list with the generated images and the second element is a list of `bool`s
|
716
|
+
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
717
|
+
"""
|
718
|
+
|
719
|
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
720
|
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
721
|
+
|
722
|
+
# 1. Check inputs. Raise error if not correct
|
723
|
+
self.check_inputs(
|
724
|
+
prompt,
|
725
|
+
negative_prompt,
|
726
|
+
height,
|
727
|
+
width,
|
728
|
+
prompt_embeds,
|
729
|
+
negative_prompt_embeds,
|
730
|
+
callback_on_step_end_tensor_inputs,
|
731
|
+
overlap_history,
|
732
|
+
num_frames,
|
733
|
+
base_num_frames,
|
734
|
+
)
|
735
|
+
|
736
|
+
if addnoise_condition > 60:
|
737
|
+
logger.warning(
|
738
|
+
f"The value of 'addnoise_condition' is too large ({addnoise_condition}) and may cause inconsistencies in long video generation. A value of 20 is recommended."
|
739
|
+
)
|
740
|
+
|
741
|
+
if num_frames % self.vae_scale_factor_temporal != 1:
|
742
|
+
logger.warning(
|
743
|
+
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
|
744
|
+
)
|
745
|
+
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
746
|
+
num_frames = max(num_frames, 1)
|
747
|
+
|
748
|
+
self._guidance_scale = guidance_scale
|
749
|
+
self._attention_kwargs = attention_kwargs
|
750
|
+
self._current_timestep = None
|
751
|
+
self._interrupt = False
|
752
|
+
|
753
|
+
device = self._execution_device
|
754
|
+
|
755
|
+
# 2. Define call parameters
|
756
|
+
if prompt is not None and isinstance(prompt, str):
|
757
|
+
batch_size = 1
|
758
|
+
elif prompt is not None and isinstance(prompt, list):
|
759
|
+
batch_size = len(prompt)
|
760
|
+
else:
|
761
|
+
batch_size = prompt_embeds.shape[0]
|
762
|
+
|
763
|
+
# 3. Encode input prompt
|
764
|
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
765
|
+
prompt=prompt,
|
766
|
+
negative_prompt=negative_prompt,
|
767
|
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
768
|
+
num_videos_per_prompt=num_videos_per_prompt,
|
769
|
+
prompt_embeds=prompt_embeds,
|
770
|
+
negative_prompt_embeds=negative_prompt_embeds,
|
771
|
+
max_sequence_length=max_sequence_length,
|
772
|
+
device=device,
|
773
|
+
)
|
774
|
+
|
775
|
+
transformer_dtype = self.transformer.dtype
|
776
|
+
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
777
|
+
if negative_prompt_embeds is not None:
|
778
|
+
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
779
|
+
|
780
|
+
# 4. Prepare timesteps
|
781
|
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
782
|
+
timesteps = self.scheduler.timesteps
|
783
|
+
|
784
|
+
if causal_block_size is None:
|
785
|
+
causal_block_size = self.transformer.config.num_frame_per_block
|
786
|
+
else:
|
787
|
+
self.transformer._set_ar_attention(causal_block_size)
|
788
|
+
|
789
|
+
fps_embeds = [fps] * prompt_embeds.shape[0]
|
790
|
+
fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
|
791
|
+
|
792
|
+
# Determine if we're doing long video generation
|
793
|
+
is_long_video = overlap_history is not None and base_num_frames is not None and num_frames > base_num_frames
|
794
|
+
# Initialize accumulated_latents to store all latents in one tensor
|
795
|
+
accumulated_latents = None
|
796
|
+
if is_long_video:
|
797
|
+
# Long video generation setup
|
798
|
+
overlap_history_latent_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1
|
799
|
+
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
800
|
+
base_latent_num_frames = (
|
801
|
+
(base_num_frames - 1) // self.vae_scale_factor_temporal + 1
|
802
|
+
if base_num_frames is not None
|
803
|
+
else num_latent_frames
|
804
|
+
)
|
805
|
+
n_iter = (
|
806
|
+
1
|
807
|
+
+ (num_latent_frames - base_latent_num_frames - 1)
|
808
|
+
// (base_latent_num_frames - overlap_history_latent_frames)
|
809
|
+
+ 1
|
810
|
+
)
|
811
|
+
else:
|
812
|
+
# Short video generation setup
|
813
|
+
n_iter = 1
|
814
|
+
base_latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
815
|
+
|
816
|
+
# Loop through iterations (multiple iterations only for long videos)
|
817
|
+
for iter_idx in range(n_iter):
|
818
|
+
if is_long_video:
|
819
|
+
logger.debug(f"Processing iteration {iter_idx + 1}/{n_iter} for long video generation...")
|
820
|
+
|
821
|
+
# 5. Prepare latent variables
|
822
|
+
num_channels_latents = self.transformer.config.in_channels
|
823
|
+
latents, current_num_latent_frames, prefix_video_latents, prefix_video_latents_frames = (
|
824
|
+
self.prepare_latents(
|
825
|
+
batch_size * num_videos_per_prompt,
|
826
|
+
num_channels_latents,
|
827
|
+
height,
|
828
|
+
width,
|
829
|
+
num_frames,
|
830
|
+
torch.float32,
|
831
|
+
device,
|
832
|
+
generator,
|
833
|
+
latents if iter_idx == 0 else None,
|
834
|
+
video_latents=accumulated_latents, # Pass latents directly instead of decoded video
|
835
|
+
base_latent_num_frames=base_latent_num_frames if is_long_video else None,
|
836
|
+
causal_block_size=causal_block_size,
|
837
|
+
overlap_history_latent_frames=overlap_history_latent_frames if is_long_video else None,
|
838
|
+
long_video_iter=iter_idx if is_long_video else None,
|
839
|
+
)
|
840
|
+
)
|
841
|
+
|
842
|
+
if prefix_video_latents_frames > 0:
|
843
|
+
latents[:, :, :prefix_video_latents_frames, :, :] = prefix_video_latents.to(transformer_dtype)
|
844
|
+
|
845
|
+
# 6. Prepare sample schedulers and timestep matrix
|
846
|
+
sample_schedulers = []
|
847
|
+
for _ in range(current_num_latent_frames):
|
848
|
+
sample_scheduler = deepcopy(self.scheduler)
|
849
|
+
sample_scheduler.set_timesteps(num_inference_steps, device=device)
|
850
|
+
sample_schedulers.append(sample_scheduler)
|
851
|
+
|
852
|
+
# Different matrix generation for short vs long video
|
853
|
+
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
|
854
|
+
current_num_latent_frames,
|
855
|
+
timesteps,
|
856
|
+
current_num_latent_frames if is_long_video else base_latent_num_frames,
|
857
|
+
ar_step,
|
858
|
+
prefix_video_latents_frames,
|
859
|
+
causal_block_size,
|
860
|
+
)
|
861
|
+
|
862
|
+
# 7. Denoising loop
|
863
|
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
864
|
+
self._num_timesteps = len(step_matrix)
|
865
|
+
|
866
|
+
with self.progress_bar(total=len(step_matrix)) as progress_bar:
|
867
|
+
for i, t in enumerate(step_matrix):
|
868
|
+
if self.interrupt:
|
869
|
+
continue
|
870
|
+
|
871
|
+
self._current_timestep = t
|
872
|
+
valid_interval_start, valid_interval_end = valid_interval[i]
|
873
|
+
latent_model_input = (
|
874
|
+
latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone()
|
875
|
+
)
|
876
|
+
timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone()
|
877
|
+
|
878
|
+
if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_frames:
|
879
|
+
noise_factor = 0.001 * addnoise_condition
|
880
|
+
latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] = (
|
881
|
+
latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :]
|
882
|
+
* (1.0 - noise_factor)
|
883
|
+
+ torch.randn_like(
|
884
|
+
latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :]
|
885
|
+
)
|
886
|
+
* noise_factor
|
887
|
+
)
|
888
|
+
timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
|
889
|
+
|
890
|
+
noise_pred = self.transformer(
|
891
|
+
hidden_states=latent_model_input,
|
892
|
+
timestep=timestep,
|
893
|
+
encoder_hidden_states=prompt_embeds,
|
894
|
+
enable_diffusion_forcing=True,
|
895
|
+
fps=fps_embeds,
|
896
|
+
attention_kwargs=attention_kwargs,
|
897
|
+
return_dict=False,
|
898
|
+
)[0]
|
899
|
+
if self.do_classifier_free_guidance:
|
900
|
+
noise_uncond = self.transformer(
|
901
|
+
hidden_states=latent_model_input,
|
902
|
+
timestep=timestep,
|
903
|
+
encoder_hidden_states=negative_prompt_embeds,
|
904
|
+
enable_diffusion_forcing=True,
|
905
|
+
fps=fps_embeds,
|
906
|
+
attention_kwargs=attention_kwargs,
|
907
|
+
return_dict=False,
|
908
|
+
)[0]
|
909
|
+
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
910
|
+
|
911
|
+
update_mask_i = step_update_mask[i]
|
912
|
+
for idx in range(valid_interval_start, valid_interval_end):
|
913
|
+
if update_mask_i[idx].item():
|
914
|
+
latents[:, :, idx, :, :] = sample_schedulers[idx].step(
|
915
|
+
noise_pred[:, :, idx - valid_interval_start, :, :],
|
916
|
+
t[idx],
|
917
|
+
latents[:, :, idx, :, :],
|
918
|
+
return_dict=False,
|
919
|
+
)[0]
|
920
|
+
|
921
|
+
if callback_on_step_end is not None:
|
922
|
+
callback_kwargs = {}
|
923
|
+
for k in callback_on_step_end_tensor_inputs:
|
924
|
+
callback_kwargs[k] = locals()[k]
|
925
|
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
926
|
+
|
927
|
+
latents = callback_outputs.pop("latents", latents)
|
928
|
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
929
|
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
930
|
+
|
931
|
+
# call the callback, if provided
|
932
|
+
if i == len(step_matrix) - 1 or (
|
933
|
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
934
|
+
):
|
935
|
+
progress_bar.update()
|
936
|
+
|
937
|
+
if XLA_AVAILABLE:
|
938
|
+
xm.mark_step()
|
939
|
+
|
940
|
+
# Handle latent accumulation for long videos or use the current latents for short videos
|
941
|
+
if is_long_video:
|
942
|
+
if accumulated_latents is None:
|
943
|
+
accumulated_latents = latents
|
944
|
+
else:
|
945
|
+
# Keep overlap frames for conditioning but don't include them in final output
|
946
|
+
accumulated_latents = torch.cat(
|
947
|
+
[accumulated_latents, latents[:, :, overlap_history_latent_frames:]], dim=2
|
948
|
+
)
|
949
|
+
|
950
|
+
if is_long_video:
|
951
|
+
latents = accumulated_latents
|
952
|
+
|
953
|
+
self._current_timestep = None
|
954
|
+
|
955
|
+
# Final decoding step - convert latents to pixels
|
956
|
+
if not output_type == "latent":
|
957
|
+
latents = latents.to(self.vae.dtype)
|
958
|
+
latents_mean = (
|
959
|
+
torch.tensor(self.vae.config.latents_mean)
|
960
|
+
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
961
|
+
.to(latents.device, latents.dtype)
|
962
|
+
)
|
963
|
+
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
964
|
+
latents.device, latents.dtype
|
965
|
+
)
|
966
|
+
latents = latents / latents_std + latents_mean
|
967
|
+
video = self.vae.decode(latents, return_dict=False)[0]
|
968
|
+
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
969
|
+
else:
|
970
|
+
video = latents
|
971
|
+
|
972
|
+
# Offload all models
|
973
|
+
self.maybe_free_model_hooks()
|
974
|
+
|
975
|
+
if not return_dict:
|
976
|
+
return (video,)
|
977
|
+
|
978
|
+
return SkyReelsV2PipelineOutput(frames=video)
|