diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,56 @@
|
|
1
|
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Optional
|
16
|
+
|
17
|
+
import torch
|
18
|
+
|
19
|
+
from ..models.attention import AttentionModuleMixin, FeedForward, LuminaFeedForward
|
20
|
+
from ..models.attention_processor import Attention, MochiAttention
|
21
|
+
|
22
|
+
|
23
|
+
# Attention and feed-forward module classes that hooks use to identify their
# target layers across the different diffusers model implementations.
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)

# Attribute names under which transformer models store their block ModuleLists.
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")

# Deduplicated union of all identifiers above.
# NOTE(review): built from a set literal, so element order is not guaranteed —
# fine for membership tests, which is how it is used.
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
    {
        *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
    }
)

# Layers supported for group offloading and layerwise casting
_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d,
    torch.nn.Conv2d,
    torch.nn.Conv3d,
    torch.nn.ConvTranspose1d,
    torch.nn.ConvTranspose2d,
    torch.nn.ConvTranspose3d,
    torch.nn.Linear,
    # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
    # because of double invocation of the same norm layer in CogVideoXLayerNorm
)
|
50
|
+
|
51
|
+
|
52
|
+
def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
|
53
|
+
for submodule_name, submodule in module.named_modules():
|
54
|
+
if submodule_name == fqn:
|
55
|
+
return submodule
|
56
|
+
return None
|
@@ -0,0 +1,293 @@
|
|
1
|
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import inspect
|
16
|
+
from dataclasses import dataclass
|
17
|
+
from typing import Any, Callable, Dict, Type
|
18
|
+
|
19
|
+
|
20
|
+
@dataclass
class AttentionProcessorMetadata:
    """Metadata describing how to bypass an attention processor's computation."""

    # Called in place of the processor when the attention layer is skipped;
    # receives the processor call's (self, *args, **kwargs) and returns the
    # input hidden states (and encoder hidden states, where applicable)
    # unchanged, emulating a no-op attention layer.
    skip_processor_output_fn: Callable[[Any], Any]
|
23
|
+
|
24
|
+
|
25
|
+
@dataclass
class TransformerBlockMetadata:
    """Metadata describing a transformer block's forward() output layout.

    Attributes:
        return_hidden_states_index: Index of `hidden_states` in the block's output
            tuple (or the sole return value when the block returns a tensor).
        return_encoder_hidden_states_index: Index of `encoder_hidden_states` in the
            output tuple, or `None` when the block does not return it.
    """

    return_hidden_states_index: int = None
    return_encoder_hidden_states_index: int = None

    # Set by TransformerBlockRegistry.register(); the block class whose
    # forward() signature is inspected for positional-argument lookup.
    _cls: Type = None
    # Lazily-built cache mapping forward() parameter names (minus `self`) to
    # their positional indices.
    _cached_parameter_indices: Dict[str, int] = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        """Return the value of forward() parameter `identifier` from `args`/`kwargs`.

        Keyword arguments take precedence; otherwise the parameter is looked up
        positionally via the cached signature of `_cls.forward`.

        Raises:
            ValueError: if the model class is not set, the parameter does not
                exist in the signature, or it was not supplied positionally.
        """
        kwargs = kwargs or {}
        if identifier in kwargs:
            return kwargs[identifier]

        # Build the name -> positional index cache on first use.
        if self._cached_parameter_indices is None:
            if self._cls is None:
                raise ValueError("Model class is not set for metadata.")
            parameters = list(inspect.signature(self._cls.forward).parameters.keys())
            parameters = parameters[1:]  # skip `self`
            self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}

        # Validate on every call. The previous warm-cache fast path raised bare
        # KeyError/IndexError, unlike the cold path's ValueError; keep the error
        # type uniform regardless of cache state.
        if identifier not in self._cached_parameter_indices:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
        index = self._cached_parameter_indices[identifier]
        if index >= len(args):
            raise ValueError(
                f"Expected at least {index + 1} positional arguments to read '{identifier}', but got {len(args)}."
            )
        return args[index]
|
50
|
+
|
51
|
+
|
52
|
+
class AttentionProcessorRegistry:
    """Global mapping from attention processor classes to their skip metadata."""

    _registry = {}
    # TODO(aryan): Deferred registration is a temporary workaround. Running the
    # registration function eagerly at module import time would create circular
    # imports with the model modules referenced in this file.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
        """Associate `metadata` with `model_class`, overwriting any previous entry."""
        cls._register()
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> AttentionProcessorMetadata:
        """Look up the metadata registered for `model_class`.

        Raises:
            ValueError: if `model_class` was never registered.
        """
        cls._register()
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]

    @classmethod
    def _register(cls):
        """Run the one-time lazy registration of the built-in processors."""
        if not cls._is_registered:
            cls._is_registered = True
            _register_attention_processors_metadata()
|
77
|
+
|
78
|
+
|
79
|
+
class TransformerBlockRegistry:
    """Global mapping from transformer block classes to their output metadata."""

    _registry = {}
    # TODO(aryan): Deferred registration is a temporary workaround. Running the
    # registration function eagerly at module import time would create circular
    # imports with the model modules referenced in this file.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
        """Associate `metadata` with `model_class` and record the class on the metadata."""
        cls._register()
        # The metadata needs the class to inspect forward() for positional lookup.
        metadata._cls = model_class
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> TransformerBlockMetadata:
        """Look up the metadata registered for `model_class`.

        Raises:
            ValueError: if `model_class` was never registered.
        """
        cls._register()
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]

    @classmethod
    def _register(cls):
        """Run the one-time lazy registration of the built-in blocks."""
        if not cls._is_registered:
            cls._is_registered = True
            _register_transformer_blocks_metadata()
|
105
|
+
|
106
|
+
|
107
|
+
def _register_attention_processors_metadata():
    """Populate AttentionProcessorRegistry with the built-in processor classes.

    Imports are local to avoid circular imports at module load time.
    """
    from ..models.attention_processor import AttnProcessor2_0
    from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
    from ..models.transformers.transformer_flux import FluxAttnProcessor
    from ..models.transformers.transformer_wan import WanAttnProcessor2_0

    # Processor class -> function that short-circuits it when attention is skipped.
    processor_skip_fns = {
        AttnProcessor2_0: _skip_proc_output_fn_Attention_AttnProcessor2_0,
        CogView4AttnProcessor: _skip_proc_output_fn_Attention_CogView4AttnProcessor,
        WanAttnProcessor2_0: _skip_proc_output_fn_Attention_WanAttnProcessor2_0,
        FluxAttnProcessor: _skip_proc_output_fn_Attention_FluxAttnProcessor,
    }
    for processor_class, skip_fn in processor_skip_fns.items():
        AttentionProcessorRegistry.register(
            model_class=processor_class,
            metadata=AttentionProcessorMetadata(skip_processor_output_fn=skip_fn),
        )
|
142
|
+
|
143
|
+
|
144
|
+
def _register_transformer_blocks_metadata():
    """Populate TransformerBlockRegistry with the built-in transformer block classes.

    Imports are local to avoid circular imports at module load time.
    """
    from ..models.attention import BasicTransformerBlock
    from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
    from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
    from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
    from ..models.transformers.transformer_hunyuan_video import (
        HunyuanVideoSingleTransformerBlock,
        HunyuanVideoTokenReplaceSingleTransformerBlock,
        HunyuanVideoTokenReplaceTransformerBlock,
        HunyuanVideoTransformerBlock,
    )
    from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
    from ..models.transformers.transformer_mochi import MochiTransformerBlock
    from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
    from ..models.transformers.transformer_wan import WanTransformerBlock

    # (block class, return_hidden_states_index, return_encoder_hidden_states_index)
    # An encoder index of None means the block does not return encoder hidden states.
    block_output_indices = [
        (BasicTransformerBlock, 0, None),
        (CogVideoXBlock, 0, 1),
        (CogView4TransformerBlock, 0, 1),
        (FluxTransformerBlock, 1, 0),
        (FluxSingleTransformerBlock, 1, 0),
        (HunyuanVideoTransformerBlock, 0, 1),
        (HunyuanVideoSingleTransformerBlock, 0, 1),
        (HunyuanVideoTokenReplaceTransformerBlock, 0, 1),
        (HunyuanVideoTokenReplaceSingleTransformerBlock, 0, 1),
        (LTXVideoTransformerBlock, 0, None),
        (MochiTransformerBlock, 0, 1),
        (WanTransformerBlock, 0, None),
        (QwenImageTransformerBlock, 1, 0),
    ]
    for block_class, hidden_index, encoder_index in block_output_indices:
        TransformerBlockRegistry.register(
            model_class=block_class,
            metadata=TransformerBlockMetadata(
                return_hidden_states_index=hidden_index,
                return_encoder_hidden_states_index=encoder_index,
            ),
        )
|
268
|
+
|
269
|
+
|
270
|
+
# fmt: off
|
271
|
+
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
|
272
|
+
hidden_states = kwargs.get("hidden_states", None)
|
273
|
+
if hidden_states is None and len(args) > 0:
|
274
|
+
hidden_states = args[0]
|
275
|
+
return hidden_states
|
276
|
+
|
277
|
+
|
278
|
+
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
|
279
|
+
hidden_states = kwargs.get("hidden_states", None)
|
280
|
+
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
281
|
+
if hidden_states is None and len(args) > 0:
|
282
|
+
hidden_states = args[0]
|
283
|
+
if encoder_hidden_states is None and len(args) > 1:
|
284
|
+
encoder_hidden_states = args[1]
|
285
|
+
return hidden_states, encoder_hidden_states
|
286
|
+
|
287
|
+
|
288
|
+
# Per-processor skip functions: each alias maps a processor to the pass-through
# matching that processor's return shape (tensor-only vs. tensor pair).
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
# NOTE(review): original comment read "not sure what this is yet" — the
# tensor-only pass-through is assumed here; confirm FluxAttnProcessor's
# return shape before relying on this.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
# fmt: on
|
diffusers/hooks/faster_cache.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -18,9 +18,10 @@ from typing import Any, Callable, List, Optional, Tuple
|
|
18
18
|
|
19
19
|
import torch
|
20
20
|
|
21
|
-
from ..models.
|
21
|
+
from ..models.attention import AttentionModuleMixin
|
22
22
|
from ..models.modeling_outputs import Transformer2DModelOutput
|
23
23
|
from ..utils import logging
|
24
|
+
from ._common import _ATTENTION_CLASSES
|
24
25
|
from .hooks import HookRegistry, ModelHook
|
25
26
|
|
26
27
|
|
@@ -29,7 +30,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|
29
30
|
|
30
31
|
_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
|
31
32
|
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
|
32
|
-
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
33
33
|
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
|
34
34
|
"^blocks.*attn",
|
35
35
|
"^transformer_blocks.*attn",
|
@@ -146,7 +146,7 @@ class FasterCacheConfig:
|
|
146
146
|
alpha_low_frequency: float = 1.1
|
147
147
|
alpha_high_frequency: float = 1.1
|
148
148
|
|
149
|
-
# n as described in CFG-Cache explanation in the paper -
|
149
|
+
# n as described in CFG-Cache explanation in the paper - dependent on the model
|
150
150
|
unconditional_batch_skip_range: int = 5
|
151
151
|
unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
|
152
152
|
|
@@ -488,9 +488,10 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> No
|
|
488
488
|
Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
|
489
489
|
|
490
490
|
Args:
|
491
|
-
|
492
|
-
The
|
493
|
-
|
491
|
+
module (`torch.nn.Module`):
|
492
|
+
The pytorch module to apply FasterCache to. Typically, this should be a transformer architecture supported
|
493
|
+
in Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
|
494
|
+
config (`FasterCacheConfig`):
|
494
495
|
The configuration to use for FasterCache.
|
495
496
|
|
496
497
|
Example:
|
@@ -588,7 +589,7 @@ def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCache
|
|
588
589
|
registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
|
589
590
|
|
590
591
|
|
591
|
-
def _apply_faster_cache_on_attention_class(name: str, module:
|
592
|
+
def _apply_faster_cache_on_attention_class(name: str, module: AttentionModuleMixin, config: FasterCacheConfig) -> None:
|
592
593
|
is_spatial_self_attention = (
|
593
594
|
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
|
594
595
|
and config.spatial_attention_block_skip_range is not None
|
@@ -0,0 +1,259 @@
|
|
1
|
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from dataclasses import dataclass
|
16
|
+
from typing import Tuple, Union
|
17
|
+
|
18
|
+
import torch
|
19
|
+
|
20
|
+
from ..utils import get_logger
|
21
|
+
from ..utils.torch_utils import unwrap_module
|
22
|
+
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
|
23
|
+
from ._helpers import TransformerBlockRegistry
|
24
|
+
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
|
25
|
+
|
26
|
+
|
27
|
+
logger = get_logger(__name__) # pylint: disable=invalid-name
|
28
|
+
|
29
|
+
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
|
30
|
+
_FBC_BLOCK_HOOK = "fbc_block_hook"
|
31
|
+
|
32
|
+
|
33
|
+
@dataclass
class FirstBlockCacheConfig:
    r"""
    Configuration for [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).

    Args:
        threshold (`float`, defaults to `0.05`):
            The threshold to determine whether or not a forward pass through all layers of the model is required. A
            higher threshold usually results in a forward pass through a lower number of layers and faster inference,
            but might lead to poorer generation quality. A lower threshold may not result in significant generation
            speedup. The threshold is compared against the absmean difference of the residuals between the current and
            cached outputs from the first transformer block. If the difference is below the threshold, the forward pass
            is skipped.
    """

    # Relative absmean-change threshold; see class docstring for semantics.
    threshold: float = 0.05
|
50
|
+
|
51
|
+
|
52
|
+
class FBCSharedBlockState(BaseState):
    """State shared between the head-block hook and the remaining block hooks."""

    def __init__(self) -> None:
        super().__init__()

        # Output of the first transformer block from the last full forward pass.
        self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        # Residual (first block output - its input hidden states) from the last
        # full pass; compared against the current residual to decide skipping.
        self.head_block_residual: torch.Tensor = None
        # Accumulated residual of all remaining blocks; added back to the head
        # output when the rest of the forward pass is skipped.
        self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        # Set by the head hook each step; read by every subsequent block hook.
        self.should_compute: bool = True

    def reset(self):
        # NOTE(review): head_block_output and head_block_residual survive reset,
        # so the first step of a new run compares against the previous run's
        # residual — verify this is intended rather than clearing them here.
        self.tail_block_residuals = None
        self.should_compute = True
|
64
|
+
|
65
|
+
|
66
|
+
class FBCHeadBlockHook(ModelHook):
    """Hook attached to the first transformer block.

    Each step it runs the block, measures how much the block's residual changed
    relative to the previous step, and records in shared state whether the
    remaining blocks need to run (`should_compute`). When they are skipped, the
    cached tail residuals are added to this block's fresh output to approximate
    the full forward pass.
    """

    _is_stateful = True

    def __init__(self, state_manager: StateManager, threshold: float):
        # Call the base initializer for consistency with FBCBlockHook.
        super().__init__()
        self.state_manager = state_manager
        self.threshold = threshold
        self._metadata = None

    def initialize_hook(self, module):
        unwrapped_module = unwrap_module(module)
        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)

        output = self.fn_ref.original_forward(*args, **kwargs)
        is_output_tuple = isinstance(output, tuple)

        if is_output_tuple:
            hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
        else:
            hidden_states_residual = output - original_hidden_states

        shared_state: FBCSharedBlockState = self.state_manager.get_state()
        hidden_states = encoder_hidden_states = None
        should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
        shared_state.should_compute = should_compute

        if not should_compute:
            # Apply caching: approximate the final outputs by adding the cached
            # tail residuals to this block's fresh output.
            if is_output_tuple:
                hidden_states = (
                    shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
                )
            else:
                hidden_states = shared_state.tail_block_residuals[0] + output

            if self._metadata.return_encoder_hidden_states_index is not None:
                assert is_output_tuple
                encoder_hidden_states = (
                    shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
                )

            if is_output_tuple:
                return_output = [None] * len(output)
                return_output[self._metadata.return_hidden_states_index] = hidden_states
                # Guard: a tuple output without an encoder-hidden-states slot
                # previously indexed the list with None and raised TypeError.
                if self._metadata.return_encoder_hidden_states_index is not None:
                    return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
                return_output = tuple(return_output)
            else:
                return_output = hidden_states
            output = return_output
        else:
            # Full pass: remember this block's output and residual for the tail
            # hook and for the next step's skip decision.
            if is_output_tuple:
                head_block_output = [None] * len(output)
                head_block_output[0] = output[self._metadata.return_hidden_states_index]
                # Same None-index guard as in the caching branch above.
                if self._metadata.return_encoder_hidden_states_index is not None:
                    head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
            else:
                head_block_output = output
            shared_state.head_block_output = head_block_output
            shared_state.head_block_residual = hidden_states_residual

        return output

    def reset_state(self, module):
        self.state_manager.reset()
        return module

    @torch.compiler.disable
    def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
        """Return True when the relative absmean change of the head residual
        versus the previous step exceeds the configured threshold."""
        shared_state = self.state_manager.get_state()
        if shared_state.head_block_residual is None:
            # No reference residual yet (first step): must compute everything.
            return True
        prev_hidden_states_residual = shared_state.head_block_residual
        absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
        prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
        diff = (absmean / prev_hidden_states_absmean).item()
        return diff > self.threshold
|
144
|
+
|
145
|
+
|
146
|
+
class FBCBlockHook(ModelHook):
    """Hook attached to every transformer block after the first.

    When the shared state says the step must be computed, the block runs
    normally; the tail block additionally records the cumulative residual
    relative to the head block's output. When the step is skipped, the block's
    inputs are passed through untouched.
    """

    def __init__(self, state_manager: StateManager, is_tail: bool = False):
        super().__init__()
        self.state_manager = state_manager
        # True only for the last block; it records the tail residuals.
        self.is_tail = is_tail
        self._metadata = None

    def initialize_hook(self, module):
        # Resolve output-layout metadata for the (possibly wrapped) block class.
        unwrapped_module = unwrap_module(module)
        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
        original_encoder_hidden_states = None
        if self._metadata.return_encoder_hidden_states_index is not None:
            original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
                "encoder_hidden_states", args, kwargs
            )

        shared_state = self.state_manager.get_state()

        if shared_state.should_compute:
            output = self.fn_ref.original_forward(*args, **kwargs)
            if self.is_tail:
                # Record (final output - head block output) so the head hook can
                # reconstruct the final outputs on skipped steps.
                hidden_states_residual = encoder_hidden_states_residual = None
                if isinstance(output, tuple):
                    hidden_states_residual = (
                        output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
                    )
                    encoder_hidden_states_residual = (
                        output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
                    )
                else:
                    hidden_states_residual = output - shared_state.head_block_output
                shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
            return output

        # Skipped step: pass the inputs through in the block's output layout.
        if original_encoder_hidden_states is None:
            return_output = original_hidden_states
        else:
            return_output = [None, None]
            return_output[self._metadata.return_hidden_states_index] = original_hidden_states
            return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
            return_output = tuple(return_output)
        return return_output
|
192
|
+
|
193
|
+
|
194
|
+
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
    """
    Applies [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
    to a given module.

    First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler
    to implement generically for a wide range of models and has been integrated first for experimental purposes.

    Args:
        module (`torch.nn.Module`):
            The pytorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
            Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
        config (`FirstBlockCacheConfig`):
            The configuration to use for applying the FBCache method.

    Raises:
        ValueError: if fewer than two transformer blocks are found on `module`
            (FBCache needs a distinct head block and tail block).

    Example:
    ```python
    >>> import torch
    >>> from diffusers import CogView4Pipeline
    >>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig

    >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")

    >>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))

    >>> prompt = "A photo of an astronaut riding a horse on mars"
    >>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
    >>> image.save("output.png")
    ```
    """

    state_manager = StateManager(FBCSharedBlockState, (), {})
    remaining_blocks = []

    # Collect all transformer blocks stored under known ModuleList attributes.
    for name, submodule in module.named_children():
        if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
            continue
        for index, block in enumerate(submodule):
            remaining_blocks.append((f"{name}.{index}", block))

    # Fail with a clear message instead of the bare IndexError the pops below
    # would raise for unsupported module layouts.
    if len(remaining_blocks) < 2:
        raise ValueError(
            "First Block Cache requires at least two transformer blocks stored under one of the attributes "
            f"{_ALL_TRANSFORMER_BLOCK_IDENTIFIERS}, but found {len(remaining_blocks)} on "
            f"{module.__class__.__name__}."
        )

    head_block_name, head_block = remaining_blocks.pop(0)
    tail_block_name, tail_block = remaining_blocks.pop(-1)

    logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
    _apply_fbc_head_block_hook(head_block, state_manager, config.threshold)

    for name, block in remaining_blocks:
        logger.debug(f"Applying FBCBlockHook to '{name}'")
        _apply_fbc_block_hook(block, state_manager)

    logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
    _apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
|
248
|
+
|
249
|
+
|
250
|
+
def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
    """Attach an FBCHeadBlockHook to `block` under the leader-block hook name."""
    hook_registry = HookRegistry.check_if_exists_or_initialize(block)
    hook_registry.register_hook(FBCHeadBlockHook(state_manager, threshold), _FBC_LEADER_BLOCK_HOOK)
|
254
|
+
|
255
|
+
|
256
|
+
def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
    """Attach an FBCBlockHook to `block`; `is_tail` marks the final block."""
    hook_registry = HookRegistry.check_if_exists_or_initialize(block)
    hook_registry.register_hook(FBCBlockHook(state_manager, is_tail), _FBC_BLOCK_HOOK)
|