diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/models/attention.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -11,23 +11,504 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
|
14
|
+
|
15
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
15
16
|
|
16
17
|
import torch
|
18
|
+
import torch.nn as nn
|
17
19
|
import torch.nn.functional as F
|
18
|
-
from torch import nn
|
19
20
|
|
20
21
|
from ..utils import deprecate, logging
|
22
|
+
from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
|
21
23
|
from ..utils.torch_utils import maybe_allow_in_graph
|
22
24
|
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
|
23
|
-
from .attention_processor import Attention, JointAttnProcessor2_0
|
25
|
+
from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
|
24
26
|
from .embeddings import SinusoidalPositionalEmbedding
|
25
27
|
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
|
26
28
|
|
27
29
|
|
30
|
+
if is_xformers_available():
|
31
|
+
import xformers as xops
|
32
|
+
else:
|
33
|
+
xops = None
|
34
|
+
|
35
|
+
|
28
36
|
logger = logging.get_logger(__name__)
|
29
37
|
|
30
38
|
|
39
|
+
class AttentionMixin:
|
40
|
+
@property
|
41
|
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
42
|
+
r"""
|
43
|
+
Returns:
|
44
|
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
45
|
+
indexed by its weight name.
|
46
|
+
"""
|
47
|
+
# set recursively
|
48
|
+
processors = {}
|
49
|
+
|
50
|
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
51
|
+
if hasattr(module, "get_processor"):
|
52
|
+
processors[f"{name}.processor"] = module.get_processor()
|
53
|
+
|
54
|
+
for sub_name, child in module.named_children():
|
55
|
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
56
|
+
|
57
|
+
return processors
|
58
|
+
|
59
|
+
for name, module in self.named_children():
|
60
|
+
fn_recursive_add_processors(name, module, processors)
|
61
|
+
|
62
|
+
return processors
|
63
|
+
|
64
|
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
65
|
+
r"""
|
66
|
+
Sets the attention processor to use to compute attention.
|
67
|
+
|
68
|
+
Parameters:
|
69
|
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
70
|
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
71
|
+
for **all** `Attention` layers.
|
72
|
+
|
73
|
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
74
|
+
processor. This is strongly recommended when setting trainable attention processors.
|
75
|
+
|
76
|
+
"""
|
77
|
+
count = len(self.attn_processors.keys())
|
78
|
+
|
79
|
+
if isinstance(processor, dict) and len(processor) != count:
|
80
|
+
raise ValueError(
|
81
|
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
82
|
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
83
|
+
)
|
84
|
+
|
85
|
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
86
|
+
if hasattr(module, "set_processor"):
|
87
|
+
if not isinstance(processor, dict):
|
88
|
+
module.set_processor(processor)
|
89
|
+
else:
|
90
|
+
module.set_processor(processor.pop(f"{name}.processor"))
|
91
|
+
|
92
|
+
for sub_name, child in module.named_children():
|
93
|
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
94
|
+
|
95
|
+
for name, module in self.named_children():
|
96
|
+
fn_recursive_attn_processor(name, module, processor)
|
97
|
+
|
98
|
+
def fuse_qkv_projections(self):
|
99
|
+
"""
|
100
|
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
101
|
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
102
|
+
"""
|
103
|
+
for _, attn_processor in self.attn_processors.items():
|
104
|
+
if "Added" in str(attn_processor.__class__.__name__):
|
105
|
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
106
|
+
|
107
|
+
for module in self.modules():
|
108
|
+
if isinstance(module, AttentionModuleMixin):
|
109
|
+
module.fuse_projections()
|
110
|
+
|
111
|
+
def unfuse_qkv_projections(self):
|
112
|
+
"""Disables the fused QKV projection if enabled.
|
113
|
+
|
114
|
+
<Tip warning={true}>
|
115
|
+
|
116
|
+
This API is 🧪 experimental.
|
117
|
+
|
118
|
+
</Tip>
|
119
|
+
"""
|
120
|
+
for module in self.modules():
|
121
|
+
if isinstance(module, AttentionModuleMixin):
|
122
|
+
module.unfuse_projections()
|
123
|
+
|
124
|
+
|
125
|
+
class AttentionModuleMixin:
|
126
|
+
_default_processor_cls = None
|
127
|
+
_available_processors = []
|
128
|
+
fused_projections = False
|
129
|
+
|
130
|
+
def set_processor(self, processor: AttentionProcessor) -> None:
|
131
|
+
"""
|
132
|
+
Set the attention processor to use.
|
133
|
+
|
134
|
+
Args:
|
135
|
+
processor (`AttnProcessor`):
|
136
|
+
The attention processor to use.
|
137
|
+
"""
|
138
|
+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
139
|
+
# pop `processor` from `self._modules`
|
140
|
+
if (
|
141
|
+
hasattr(self, "processor")
|
142
|
+
and isinstance(self.processor, torch.nn.Module)
|
143
|
+
and not isinstance(processor, torch.nn.Module)
|
144
|
+
):
|
145
|
+
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
146
|
+
self._modules.pop("processor")
|
147
|
+
|
148
|
+
self.processor = processor
|
149
|
+
|
150
|
+
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
|
151
|
+
"""
|
152
|
+
Get the attention processor in use.
|
153
|
+
|
154
|
+
Args:
|
155
|
+
return_deprecated_lora (`bool`, *optional*, defaults to `False`):
|
156
|
+
Set to `True` to return the deprecated LoRA attention processor.
|
157
|
+
|
158
|
+
Returns:
|
159
|
+
"AttentionProcessor": The attention processor in use.
|
160
|
+
"""
|
161
|
+
if not return_deprecated_lora:
|
162
|
+
return self.processor
|
163
|
+
|
164
|
+
def set_attention_backend(self, backend: str):
|
165
|
+
from .attention_dispatch import AttentionBackendName
|
166
|
+
|
167
|
+
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
|
168
|
+
if backend not in available_backends:
|
169
|
+
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
|
170
|
+
|
171
|
+
backend = AttentionBackendName(backend.lower())
|
172
|
+
self.processor._attention_backend = backend
|
173
|
+
|
174
|
+
def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
|
175
|
+
"""
|
176
|
+
Set whether to use NPU flash attention from `torch_npu` or not.
|
177
|
+
|
178
|
+
Args:
|
179
|
+
use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
|
180
|
+
"""
|
181
|
+
|
182
|
+
if use_npu_flash_attention:
|
183
|
+
if not is_torch_npu_available():
|
184
|
+
raise ImportError("torch_npu is not available")
|
185
|
+
|
186
|
+
self.set_attention_backend("_native_npu")
|
187
|
+
|
188
|
+
def set_use_xla_flash_attention(
|
189
|
+
self,
|
190
|
+
use_xla_flash_attention: bool,
|
191
|
+
partition_spec: Optional[Tuple[Optional[str], ...]] = None,
|
192
|
+
is_flux=False,
|
193
|
+
) -> None:
|
194
|
+
"""
|
195
|
+
Set whether to use XLA flash attention from `torch_xla` or not.
|
196
|
+
|
197
|
+
Args:
|
198
|
+
use_xla_flash_attention (`bool`):
|
199
|
+
Whether to use pallas flash attention kernel from `torch_xla` or not.
|
200
|
+
partition_spec (`Tuple[]`, *optional*):
|
201
|
+
Specify the partition specification if using SPMD. Otherwise None.
|
202
|
+
is_flux (`bool`, *optional*, defaults to `False`):
|
203
|
+
Whether the model is a Flux model.
|
204
|
+
"""
|
205
|
+
if use_xla_flash_attention:
|
206
|
+
if not is_torch_xla_available():
|
207
|
+
raise ImportError("torch_xla is not available")
|
208
|
+
|
209
|
+
self.set_attention_backend("_native_xla")
|
210
|
+
|
211
|
+
def set_use_memory_efficient_attention_xformers(
|
212
|
+
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
213
|
+
) -> None:
|
214
|
+
"""
|
215
|
+
Set whether to use memory efficient attention from `xformers` or not.
|
216
|
+
|
217
|
+
Args:
|
218
|
+
use_memory_efficient_attention_xformers (`bool`):
|
219
|
+
Whether to use memory efficient attention from `xformers` or not.
|
220
|
+
attention_op (`Callable`, *optional*):
|
221
|
+
The attention operation to use. Defaults to `None` which uses the default attention operation from
|
222
|
+
`xformers`.
|
223
|
+
"""
|
224
|
+
if use_memory_efficient_attention_xformers:
|
225
|
+
if not is_xformers_available():
|
226
|
+
raise ModuleNotFoundError(
|
227
|
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
|
228
|
+
name="xformers",
|
229
|
+
)
|
230
|
+
elif not torch.cuda.is_available():
|
231
|
+
raise ValueError(
|
232
|
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
233
|
+
" only available for GPU "
|
234
|
+
)
|
235
|
+
else:
|
236
|
+
try:
|
237
|
+
# Make sure we can run the memory efficient attention
|
238
|
+
if is_xformers_available():
|
239
|
+
dtype = None
|
240
|
+
if attention_op is not None:
|
241
|
+
op_fw, op_bw = attention_op
|
242
|
+
dtype, *_ = op_fw.SUPPORTED_DTYPES
|
243
|
+
q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
|
244
|
+
_ = xops.memory_efficient_attention(q, q, q)
|
245
|
+
except Exception as e:
|
246
|
+
raise e
|
247
|
+
|
248
|
+
self.set_attention_backend("xformers")
|
249
|
+
|
250
|
+
@torch.no_grad()
|
251
|
+
def fuse_projections(self):
|
252
|
+
"""
|
253
|
+
Fuse the query, key, and value projections into a single projection for efficiency.
|
254
|
+
"""
|
255
|
+
# Skip if already fused
|
256
|
+
if getattr(self, "fused_projections", False):
|
257
|
+
return
|
258
|
+
|
259
|
+
device = self.to_q.weight.data.device
|
260
|
+
dtype = self.to_q.weight.data.dtype
|
261
|
+
|
262
|
+
if hasattr(self, "is_cross_attention") and self.is_cross_attention:
|
263
|
+
# Fuse cross-attention key-value projections
|
264
|
+
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
|
265
|
+
in_features = concatenated_weights.shape[1]
|
266
|
+
out_features = concatenated_weights.shape[0]
|
267
|
+
|
268
|
+
self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
|
269
|
+
self.to_kv.weight.copy_(concatenated_weights)
|
270
|
+
if hasattr(self, "use_bias") and self.use_bias:
|
271
|
+
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
|
272
|
+
self.to_kv.bias.copy_(concatenated_bias)
|
273
|
+
else:
|
274
|
+
# Fuse self-attention projections
|
275
|
+
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
|
276
|
+
in_features = concatenated_weights.shape[1]
|
277
|
+
out_features = concatenated_weights.shape[0]
|
278
|
+
|
279
|
+
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
|
280
|
+
self.to_qkv.weight.copy_(concatenated_weights)
|
281
|
+
if hasattr(self, "use_bias") and self.use_bias:
|
282
|
+
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
|
283
|
+
self.to_qkv.bias.copy_(concatenated_bias)
|
284
|
+
|
285
|
+
# Handle added projections for models like SD3, Flux, etc.
|
286
|
+
if (
|
287
|
+
getattr(self, "add_q_proj", None) is not None
|
288
|
+
and getattr(self, "add_k_proj", None) is not None
|
289
|
+
and getattr(self, "add_v_proj", None) is not None
|
290
|
+
):
|
291
|
+
concatenated_weights = torch.cat(
|
292
|
+
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
|
293
|
+
)
|
294
|
+
in_features = concatenated_weights.shape[1]
|
295
|
+
out_features = concatenated_weights.shape[0]
|
296
|
+
|
297
|
+
self.to_added_qkv = nn.Linear(
|
298
|
+
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
|
299
|
+
)
|
300
|
+
self.to_added_qkv.weight.copy_(concatenated_weights)
|
301
|
+
if self.added_proj_bias:
|
302
|
+
concatenated_bias = torch.cat(
|
303
|
+
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
|
304
|
+
)
|
305
|
+
self.to_added_qkv.bias.copy_(concatenated_bias)
|
306
|
+
|
307
|
+
self.fused_projections = True
|
308
|
+
|
309
|
+
@torch.no_grad()
|
310
|
+
def unfuse_projections(self):
|
311
|
+
"""
|
312
|
+
Unfuse the query, key, and value projections back to separate projections.
|
313
|
+
"""
|
314
|
+
# Skip if not fused
|
315
|
+
if not getattr(self, "fused_projections", False):
|
316
|
+
return
|
317
|
+
|
318
|
+
# Remove fused projection layers
|
319
|
+
if hasattr(self, "to_qkv"):
|
320
|
+
delattr(self, "to_qkv")
|
321
|
+
|
322
|
+
if hasattr(self, "to_kv"):
|
323
|
+
delattr(self, "to_kv")
|
324
|
+
|
325
|
+
if hasattr(self, "to_added_qkv"):
|
326
|
+
delattr(self, "to_added_qkv")
|
327
|
+
|
328
|
+
self.fused_projections = False
|
329
|
+
|
330
|
+
def set_attention_slice(self, slice_size: int) -> None:
|
331
|
+
"""
|
332
|
+
Set the slice size for attention computation.
|
333
|
+
|
334
|
+
Args:
|
335
|
+
slice_size (`int`):
|
336
|
+
The slice size for attention computation.
|
337
|
+
"""
|
338
|
+
if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim:
|
339
|
+
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
340
|
+
|
341
|
+
processor = None
|
342
|
+
|
343
|
+
# Try to get a compatible processor for sliced attention
|
344
|
+
if slice_size is not None:
|
345
|
+
processor = self._get_compatible_processor("sliced")
|
346
|
+
|
347
|
+
# If no processor was found or slice_size is None, use default processor
|
348
|
+
if processor is None:
|
349
|
+
processor = self.default_processor_cls()
|
350
|
+
|
351
|
+
self.set_processor(processor)
|
352
|
+
|
353
|
+
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
|
354
|
+
"""
|
355
|
+
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.
|
356
|
+
|
357
|
+
Args:
|
358
|
+
tensor (`torch.Tensor`): The tensor to reshape.
|
359
|
+
|
360
|
+
Returns:
|
361
|
+
`torch.Tensor`: The reshaped tensor.
|
362
|
+
"""
|
363
|
+
head_size = self.heads
|
364
|
+
batch_size, seq_len, dim = tensor.shape
|
365
|
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
366
|
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
367
|
+
return tensor
|
368
|
+
|
369
|
+
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
|
370
|
+
"""
|
371
|
+
Reshape the tensor for multi-head attention processing.
|
372
|
+
|
373
|
+
Args:
|
374
|
+
tensor (`torch.Tensor`): The tensor to reshape.
|
375
|
+
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.
|
376
|
+
|
377
|
+
Returns:
|
378
|
+
`torch.Tensor`: The reshaped tensor.
|
379
|
+
"""
|
380
|
+
head_size = self.heads
|
381
|
+
if tensor.ndim == 3:
|
382
|
+
batch_size, seq_len, dim = tensor.shape
|
383
|
+
extra_dim = 1
|
384
|
+
else:
|
385
|
+
batch_size, extra_dim, seq_len, dim = tensor.shape
|
386
|
+
tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
|
387
|
+
tensor = tensor.permute(0, 2, 1, 3)
|
388
|
+
|
389
|
+
if out_dim == 3:
|
390
|
+
tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
|
391
|
+
|
392
|
+
return tensor
|
393
|
+
|
394
|
+
def get_attention_scores(
|
395
|
+
self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
|
396
|
+
) -> torch.Tensor:
|
397
|
+
"""
|
398
|
+
Compute the attention scores.
|
399
|
+
|
400
|
+
Args:
|
401
|
+
query (`torch.Tensor`): The query tensor.
|
402
|
+
key (`torch.Tensor`): The key tensor.
|
403
|
+
attention_mask (`torch.Tensor`, *optional*): The attention mask to use.
|
404
|
+
|
405
|
+
Returns:
|
406
|
+
`torch.Tensor`: The attention probabilities/scores.
|
407
|
+
"""
|
408
|
+
dtype = query.dtype
|
409
|
+
if self.upcast_attention:
|
410
|
+
query = query.float()
|
411
|
+
key = key.float()
|
412
|
+
|
413
|
+
if attention_mask is None:
|
414
|
+
baddbmm_input = torch.empty(
|
415
|
+
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
416
|
+
)
|
417
|
+
beta = 0
|
418
|
+
else:
|
419
|
+
baddbmm_input = attention_mask
|
420
|
+
beta = 1
|
421
|
+
|
422
|
+
attention_scores = torch.baddbmm(
|
423
|
+
baddbmm_input,
|
424
|
+
query,
|
425
|
+
key.transpose(-1, -2),
|
426
|
+
beta=beta,
|
427
|
+
alpha=self.scale,
|
428
|
+
)
|
429
|
+
del baddbmm_input
|
430
|
+
|
431
|
+
if self.upcast_softmax:
|
432
|
+
attention_scores = attention_scores.float()
|
433
|
+
|
434
|
+
attention_probs = attention_scores.softmax(dim=-1)
|
435
|
+
del attention_scores
|
436
|
+
|
437
|
+
attention_probs = attention_probs.to(dtype)
|
438
|
+
|
439
|
+
return attention_probs
|
440
|
+
|
441
|
+
def prepare_attention_mask(
|
442
|
+
self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
|
443
|
+
) -> torch.Tensor:
|
444
|
+
"""
|
445
|
+
Prepare the attention mask for the attention computation.
|
446
|
+
|
447
|
+
Args:
|
448
|
+
attention_mask (`torch.Tensor`): The attention mask to prepare.
|
449
|
+
target_length (`int`): The target length of the attention mask.
|
450
|
+
batch_size (`int`): The batch size for repeating the attention mask.
|
451
|
+
out_dim (`int`, *optional*, defaults to `3`): Output dimension.
|
452
|
+
|
453
|
+
Returns:
|
454
|
+
`torch.Tensor`: The prepared attention mask.
|
455
|
+
"""
|
456
|
+
head_size = self.heads
|
457
|
+
if attention_mask is None:
|
458
|
+
return attention_mask
|
459
|
+
|
460
|
+
current_length: int = attention_mask.shape[-1]
|
461
|
+
if current_length != target_length:
|
462
|
+
if attention_mask.device.type == "mps":
|
463
|
+
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
464
|
+
# Instead, we can manually construct the padding tensor.
|
465
|
+
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
|
466
|
+
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
|
467
|
+
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
468
|
+
else:
|
469
|
+
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
470
|
+
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
471
|
+
# remaining_length: int = target_length - current_length
|
472
|
+
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
473
|
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
474
|
+
|
475
|
+
if out_dim == 3:
|
476
|
+
if attention_mask.shape[0] < batch_size * head_size:
|
477
|
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
478
|
+
elif out_dim == 4:
|
479
|
+
attention_mask = attention_mask.unsqueeze(1)
|
480
|
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
481
|
+
|
482
|
+
return attention_mask
|
483
|
+
|
484
|
+
def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
|
485
|
+
"""
|
486
|
+
Normalize the encoder hidden states.
|
487
|
+
|
488
|
+
Args:
|
489
|
+
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
|
490
|
+
|
491
|
+
Returns:
|
492
|
+
`torch.Tensor`: The normalized encoder hidden states.
|
493
|
+
"""
|
494
|
+
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
495
|
+
if isinstance(self.norm_cross, nn.LayerNorm):
|
496
|
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
497
|
+
elif isinstance(self.norm_cross, nn.GroupNorm):
|
498
|
+
# Group norm norms along the channels dimension and expects
|
499
|
+
# input to be in the shape of (N, C, *). In this case, we want
|
500
|
+
# to norm along the hidden dimension, so we need to move
|
501
|
+
# (batch_size, sequence_length, hidden_size) ->
|
502
|
+
# (batch_size, hidden_size, sequence_length)
|
503
|
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
504
|
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
505
|
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
506
|
+
else:
|
507
|
+
assert False
|
508
|
+
|
509
|
+
return encoder_hidden_states
|
510
|
+
|
511
|
+
|
31
512
|
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
|
32
513
|
# "feed_forward_chunk_size" can be used to save memory
|
33
514
|
if hidden_states.shape[chunk_dim] % chunk_size != 0:
|
@@ -90,7 +571,7 @@ class JointTransformerBlock(nn.Module):
|
|
90
571
|
r"""
|
91
572
|
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
92
573
|
|
93
|
-
Reference: https://
|
574
|
+
Reference: https://huggingface.co/papers/2403.03206
|
94
575
|
|
95
576
|
Parameters:
|
96
577
|
dim (`int`): The number of channels in the input and output.
|
@@ -892,8 +1373,8 @@ class FreeNoiseTransformerBlock(nn.Module):
|
|
892
1373
|
The number of frames to be skipped before starting to process a new batch of `context_length` frames.
|
893
1374
|
weighting_scheme (`str`, defaults to `"pyramid"`):
|
894
1375
|
The weighting scheme to use for weighting averaging of processed latent frames. As described in the
|
895
|
-
Equation 9. of the [FreeNoise](https://
|
896
|
-
used.
|
1376
|
+
Equation 9. of the [FreeNoise](https://huggingface.co/papers/2310.15169) paper, "pyramid" is the default
|
1377
|
+
setting used.
|
897
1378
|
"""
|
898
1379
|
|
899
1380
|
def __init__(
|