diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/models/attention_dispatch.py
@@ -0,0 +1,1218 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import functools
+import inspect
+import math
+from enum import Enum
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+
+import torch
+
+from ..utils import (
+    get_logger,
+    is_flash_attn_3_available,
+    is_flash_attn_available,
+    is_flash_attn_version,
+    is_sageattention_available,
+    is_sageattention_version,
+    is_torch_npu_available,
+    is_torch_version,
+    is_torch_xla_available,
+    is_torch_xla_version,
+    is_xformers_available,
+    is_xformers_version,
+)
+from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
+
+
+_REQUIRED_FLASH_VERSION = "2.6.3"
+_REQUIRED_SAGE_VERSION = "2.1.1"
+_REQUIRED_FLEX_VERSION = "2.5.0"
+_REQUIRED_XLA_VERSION = "2.2"
+_REQUIRED_XFORMERS_VERSION = "0.0.29"
+
+_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
+_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
+_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
+_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
+_CAN_USE_NPU_ATTN = is_torch_npu_available()
+_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
+_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
+
+
+if _CAN_USE_FLASH_ATTN:
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+else:
+    flash_attn_func = None
+    flash_attn_varlen_func = None
+
+
+if _CAN_USE_FLASH_ATTN_3:
+    from flash_attn_interface import flash_attn_func as flash_attn_3_func
+    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+else:
+    flash_attn_3_func = None
+    flash_attn_3_varlen_func = None
+
+
+if _CAN_USE_SAGE_ATTN:
+    from sageattention import (
+        sageattn,
+        sageattn_qk_int8_pv_fp8_cuda,
+        sageattn_qk_int8_pv_fp8_cuda_sm90,
+        sageattn_qk_int8_pv_fp16_cuda,
+        sageattn_qk_int8_pv_fp16_triton,
+        sageattn_varlen,
+    )
+else:
+    sageattn = None
+    sageattn_qk_int8_pv_fp16_cuda = None
+    sageattn_qk_int8_pv_fp16_triton = None
+    sageattn_qk_int8_pv_fp8_cuda = None
+    sageattn_qk_int8_pv_fp8_cuda_sm90 = None
+    sageattn_varlen = None
+
+
+if _CAN_USE_FLEX_ATTN:
+    # We cannot import the flex_attention function from the package directly because it is expected (from the
+    # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
+    # compiled function.
+    import torch.nn.attention.flex_attention as flex_attention
+
+
+if _CAN_USE_NPU_ATTN:
+    from torch_npu import npu_fusion_attention
+else:
+    npu_fusion_attention = None
+
+
+if _CAN_USE_XLA_ATTN:
+    from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
+else:
+    xla_flash_attention = None
+
+
+if _CAN_USE_XFORMERS_ATTN:
+    import xformers.ops as xops
+else:
+    xops = None
+
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+# TODO(aryan): Add support for the following:
+# - Sage Attention++
+# - block sparse, radial and other attention methods
+# - CP with sage attention, flex, xformers, other missing backends
+# - Add support for normal and CP training with backends that don't support it yet
+
+_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
+_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
+_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
+
+
+class AttentionBackendName(str, Enum):
+    # EAGER = "eager"
+
+    # `flash-attn`
+    FLASH = "flash"
+    FLASH_VARLEN = "flash_varlen"
+    _FLASH_3 = "_flash_3"
+    _FLASH_VARLEN_3 = "_flash_varlen_3"
+
+    # PyTorch native
+    FLEX = "flex"
+    NATIVE = "native"
+    _NATIVE_CUDNN = "_native_cudnn"
+    _NATIVE_EFFICIENT = "_native_efficient"
+    _NATIVE_FLASH = "_native_flash"
+    _NATIVE_MATH = "_native_math"
+    _NATIVE_NPU = "_native_npu"
+    _NATIVE_XLA = "_native_xla"
+
+    # `sageattention`
+    SAGE = "sage"
+    SAGE_VARLEN = "sage_varlen"
+    _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
+    _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
+    _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda"
+    _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton"
+    # TODO: let's not add support for Sparge Attention now because it requires tuning per model
+    # We can look into supporting something "autotune"-ing in the future
+    # SPARGE = "sparge"
+
+    # `xformers`
+    XFORMERS = "xformers"
+
+
+class _AttentionBackendRegistry:
+    _backends = {}
+    _constraints = {}
+    _supported_arg_names = {}
+    _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
+    _checks_enabled = DIFFUSERS_ATTN_CHECKS
+
+    @classmethod
+    def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None):
+        logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")
+
+        def decorator(func):
+            cls._backends[backend] = func
+            cls._constraints[backend] = constraints or []
+            cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
+            return func
+
+        return decorator
+
+    @classmethod
+    def get_active_backend(cls):
+        return cls._active_backend, cls._backends[cls._active_backend]
+
+    @classmethod
+    def list_backends(cls):
+        return list(cls._backends.keys())
+
+
+@contextlib.contextmanager
+def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
+    """
+    Context manager to set the active attention backend.
+    """
+    if backend not in _AttentionBackendRegistry._backends:
+        raise ValueError(f"Backend {backend} is not registered.")
+
+    backend = AttentionBackendName(backend)
+    _check_attention_backend_requirements(backend)
+
+    old_backend = _AttentionBackendRegistry._active_backend
+    _AttentionBackendRegistry._active_backend = backend
+
+    try:
+        yield
+    finally:
+        _AttentionBackendRegistry._active_backend = old_backend
+
+
+def dispatch_attention_fn(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    attention_kwargs: Optional[Dict[str, Any]] = None,
+    *,
+    backend: Optional[AttentionBackendName] = None,
+) -> torch.Tensor:
+    attention_kwargs = attention_kwargs or {}
+
+    if backend is None:
+        # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment
+        # variable), or we use a custom backend based on whether user is using the `attention_backend` context manager
+        backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend()
+    else:
+        backend_name = AttentionBackendName(backend)
+        backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
+
+    kwargs = {
+        "query": query,
+        "key": key,
+        "value": value,
+        "attn_mask": attn_mask,
+        "dropout_p": dropout_p,
+        "is_causal": is_causal,
+        "scale": scale,
+        **attention_kwargs,
+    }
+    if is_torch_version(">=", "2.5.0"):
+        kwargs["enable_gqa"] = enable_gqa
+
+    if _AttentionBackendRegistry._checks_enabled:
+        removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
+        if removed_kwargs:
+            logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.")
+        for check in _AttentionBackendRegistry._constraints.get(backend_name):
+            check(**kwargs)
+
+    kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
+    return backend_fn(**kwargs)
+
+
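The hunk above defines the public surface of the new dispatcher: backends register themselves with `_AttentionBackendRegistry`, `attention_backend(...)` temporarily switches the active backend, and `dispatch_attention_fn(...)` routes a query/key/value triple to whichever backend is active. A minimal usage sketch, not part of the diff; it assumes the `(batch, seq_len, num_heads, head_dim)` layout used by the flash-attn style signatures registered in this module, and that any optional backend package is installed where noted:

```python
# Usage sketch (illustration only). Assumes the (batch, seq_len, num_heads, head_dim)
# tensor layout used by the flash-attn-style backend signatures in this module.
import torch

from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    attention_backend,
    dispatch_attention_fn,
)

q = torch.randn(2, 128, 8, 64)
k = torch.randn(2, 128, 8, 64)
v = torch.randn(2, 128, 8, 64)

# Per-call selection: route this call to the default PyTorch-native (SDPA) backend.
out = dispatch_attention_fn(q, k, v, backend=AttentionBackendName.NATIVE)

# Scoped selection: the context manager checks that the backend is registered and
# that its package requirements are met, then restores the previous backend on exit.
# "flash" needs flash-attn>=2.6.3, CUDA tensors, and fp16/bf16 dtypes.
if torch.cuda.is_available():
    q, k, v = (x.cuda().half() for x in (q, k, v))
    with attention_backend("flash"):
        out = dispatch_attention_fn(q, k, v)

print(out.shape)  # expected: torch.Size([2, 128, 8, 64])
```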
+# ===== Checks =====
+# A list of very simple functions to catch common errors quickly when debugging.
+
+
+def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None:
+    if attn_mask is not None and is_causal:
+        raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.")
+
+
+def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+    if query.device != key.device or query.device != value.device:
+        raise ValueError("Query, key, and value must be on the same device.")
+    if query.dtype != key.dtype or query.dtype != value.dtype:
+        raise ValueError("Query, key, and value must have the same dtype.")
+
+
+def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+    _check_device(query, key, value)
+    if query.device.type != "cuda":
+        raise ValueError("Query, key, and value must be on a CUDA device.")
+
+
+def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable:
+    def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+        _check_device_cuda(query, key, value)
+        if torch.cuda.get_device_capability(query.device) < (major, minor):
+            raise ValueError(
+                f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}."
+            )
+
+    return check_device_cuda
+
+
+def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+    if query.dtype != key.dtype:
+        raise ValueError("Query and key must have the same dtype.")
+    if query.dtype != value.dtype:
+        raise ValueError("Query and value must have the same dtype.")
+
+
+def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+    _check_qkv_dtype_match(query, key, value)
+    if query.dtype not in (torch.bfloat16, torch.float16):
+        raise ValueError("Query, key, and value must be either bfloat16 or float16.")
+
+
+def _check_shape(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+) -> None:
+    if query.shape[-1] != key.shape[-1]:
+        raise ValueError("Query and key must have the same last dimension.")
+    if query.shape[-2] != value.shape[-2]:
+        raise ValueError("Query and value must have the same second to last dimension.")
+    if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
+        raise ValueError("Attention mask must match the key's second to last dimension.")
+
+
+# ===== Helper functions =====
+
+
+def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
+    if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
+        if not _CAN_USE_FLASH_ATTN:
+            raise RuntimeError(
+                f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
+            )
+
+    elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
+        if not _CAN_USE_FLASH_ATTN_3:
+            raise RuntimeError(
+                f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
+            )
+
+    elif backend in [
+        AttentionBackendName.SAGE,
+        AttentionBackendName.SAGE_VARLEN,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
+    ]:
+        if not _CAN_USE_SAGE_ATTN:
+            raise RuntimeError(
+                f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
+            )
+
+    elif backend == AttentionBackendName.FLEX:
+        if not _CAN_USE_FLEX_ATTN:
+            raise RuntimeError(
+                f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`."
+            )
+
+    elif backend == AttentionBackendName._NATIVE_NPU:
+        if not _CAN_USE_NPU_ATTN:
+            raise RuntimeError(
+                f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`."
+            )
+
+    elif backend == AttentionBackendName._NATIVE_XLA:
+        if not _CAN_USE_XLA_ATTN:
+            raise RuntimeError(
+                f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`."
+            )
+
+    elif backend == AttentionBackendName.XFORMERS:
+        if not _CAN_USE_XFORMERS_ATTN:
+            raise RuntimeError(
+                f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
+            )
+
+
+@functools.lru_cache(maxsize=128)
+def _prepare_for_flash_attn_or_sage_varlen_without_mask(
+    batch_size: int,
+    seq_len_q: int,
+    seq_len_kv: int,
+    device: Optional[torch.device] = None,
+):
+    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
+    seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
+    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+    max_seqlen_q = seqlens_q.max().item()
+    max_seqlen_k = seqlens_k.max().item()
+    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+
+
+def _prepare_for_flash_attn_or_sage_varlen_with_mask(
+    batch_size: int,
+    seq_len_q: int,
+    attn_mask: torch.Tensor,
+    device: Optional[torch.device] = None,
+):
+    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
+    seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
+    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+    max_seqlen_q = seqlens_q.max().item()
+    max_seqlen_k = seqlens_k.max().item()
+    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+
+
+def _prepare_for_flash_attn_or_sage_varlen(
+    batch_size: int,
+    seq_len_q: int,
+    seq_len_kv: int,
+    attn_mask: Optional[torch.Tensor] = None,
+    device: Optional[torch.device] = None,
+) -> None:
+    if attn_mask is None:
+        return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)
+    return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)
+
+
+def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
+    """
+    Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in
+    FlashAttention/Sage varlen.
+
+    Supports 1D to 4D shapes and common broadcasting patterns.
+    """
+    if attn_mask.dtype != torch.bool:
+        raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")
+
+    if attn_mask.ndim == 1:
+        # [seq_len_k] -> broadcast across batch
+        attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)
+
+    elif attn_mask.ndim == 2:
+        # [batch_size, seq_len_k]. Maybe broadcast across batch
+        if attn_mask.size(0) not in [1, batch_size]:
+            raise ValueError(
+                f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask."
+            )
+        attn_mask = attn_mask.expand(batch_size, seq_len_k)
+
+    elif attn_mask.ndim == 3:
+        # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension
+        # We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen.
+        if attn_mask.size(0) not in [1, batch_size]:
+            raise ValueError(
+                f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask."
+            )
+        attn_mask = attn_mask.any(dim=1)
+        attn_mask = attn_mask.expand(batch_size, seq_len_k)
+
+    elif attn_mask.ndim == 4:
+        # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions
+        if attn_mask.size(0) not in [1, batch_size]:
+            raise ValueError(
+                f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask."
+            )
+        attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k)  # [B, H, Q, K]
+        attn_mask = attn_mask.any(dim=(1, 2))  # [B, K]
+
+    else:
+        raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}")
+
+    if attn_mask.shape != (batch_size, seq_len_k):
+        raise ValueError(
+            f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})"
+        )
+
+    return attn_mask
+
+
+def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+    return q_idx >= kv_idx
+
+
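The varlen helpers above turn a boolean key-padding mask into the cumulative sequence offsets (`cu_seqlens_*`) and maximum lengths that the flash-attn/sage varlen kernels consume. A small worked illustration of that bookkeeping, not part of the diff:

```python
# Illustration of the cu_seqlens construction performed by the helpers above.
import torch

attn_mask = torch.tensor(
    [
        [True, True, True, True],    # item 0: 4 valid key tokens
        [True, True, False, False],  # item 1: 2 valid key tokens (padded)
    ]
)
batch_size, seq_len_q = attn_mask.shape[0], 4

seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32)
seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)  # tensor([4, 2])

cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32)
cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32)
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)  # tensor([0, 4, 8])
cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)  # tensor([0, 4, 6])

max_seqlen_q, max_seqlen_k = int(seqlens_q.max()), int(seqlens_k.max())  # 4, 4
```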
+# ===== torch op registrations =====
+# Registrations are required for fullgraph tracing compatibility
+
+
+# TODO: library.custom_op and register_fake probably need version guards?
+# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
+# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
+@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+def _wrapped_flash_attn_3_original(
+    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    out, lse = flash_attn_3_func(query, key, value)
+    lse = lse.permute(0, 2, 1)
+    return out, lse
+
+
+@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
+def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    batch_size, seq_len, num_heads, head_dim = query.shape
+    lse_shape = (batch_size, seq_len, num_heads)
+    return torch.empty_like(query), query.new_empty(lse_shape)
+
+
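The `torch.library` registrations above exist so the FA3 kernel can be traced with `torch.compile`: `custom_op` wraps the opaque kernel as a proper PyTorch operator, and `register_fake` supplies a shape/dtype-only meta implementation. A generic sketch of the same pattern, illustration only; `mylib::toy_attention` is a made-up operator name, and the decorators require a recent PyTorch (2.4+):

```python
# Sketch of the custom_op / register_fake pattern (not part of the diff).
import torch


@torch.library.custom_op("mylib::toy_attention", mutates_args=())
def toy_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Real implementation: plain scaled dot-product attention.
    scores = query @ key.transpose(-2, -1) / query.shape[-1] ** 0.5
    return scores.softmax(dim=-1) @ value


@torch.library.register_fake("mylib::toy_attention")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Fake/meta implementation: only output shape and dtype matter for tracing.
    return torch.empty_like(query)


q = k = v = torch.randn(1, 8, 16, 32)
out = torch.ops.mylib.toy_attention(q, k, v)  # dispatches through the registered op
```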
|
497
|
+
# ===== Attention backends =====
|
498
|
+
|
499
|
+
|
500
|
+
@_AttentionBackendRegistry.register(
|
501
|
+
AttentionBackendName.FLASH,
|
502
|
+
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
503
|
+
)
|
504
|
+
def _flash_attention(
|
505
|
+
query: torch.Tensor,
|
506
|
+
key: torch.Tensor,
|
507
|
+
value: torch.Tensor,
|
508
|
+
dropout_p: float = 0.0,
|
509
|
+
scale: Optional[float] = None,
|
510
|
+
is_causal: bool = False,
|
511
|
+
window_size: Tuple[int, int] = (-1, -1),
|
512
|
+
softcap: float = 0.0,
|
513
|
+
alibi_slopes: Optional[torch.Tensor] = None,
|
514
|
+
deterministic: bool = False,
|
515
|
+
return_attn_probs: bool = False,
|
516
|
+
) -> torch.Tensor:
|
517
|
+
out = flash_attn_func(
|
518
|
+
q=query,
|
519
|
+
k=key,
|
520
|
+
v=value,
|
521
|
+
dropout_p=dropout_p,
|
522
|
+
softmax_scale=scale,
|
523
|
+
causal=is_causal,
|
524
|
+
window_size=window_size,
|
525
|
+
softcap=softcap,
|
526
|
+
alibi_slopes=alibi_slopes,
|
527
|
+
deterministic=deterministic,
|
528
|
+
return_attn_probs=return_attn_probs,
|
529
|
+
)
|
530
|
+
return out
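
A hedged usage sketch for the underlying kernel call: `flash_attn_func` consumes `[batch, seq, heads, head_dim]` tensors in bf16/fp16 on CUDA, which matches the dtype/device constraints registered above. Assumes the flash-attn package (2.x) is installed and a CUDA device is available.

```python
import torch
from flash_attn import flash_attn_func  # assumes the flash-attn package is installed

q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
out = flash_attn_func(q, k, v, causal=True)  # output is also [batch, seq, heads, head_dim]
```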
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_VARLEN,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_varlen_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k: Optional[torch.Tensor] = None,
+    max_seqlen_q: Optional[int] = None,
+    max_seqlen_k: Optional[int] = None,
+    dropout_p: float = 0.0,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size: Tuple[int, int] = (-1, -1),
+    softcap: float = 0.0,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    deterministic: bool = False,
+    return_attn_probs: bool = False,
+    attn_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    batch_size, seq_len_q, _, _ = query.shape
+    _, seq_len_kv, _, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen(
+                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+            )
+        )
+    else:
+        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+    key_valid, value_valid = [], []
+    for b in range(batch_size):
+        valid_len = seqlens_k[b]
+        key_valid.append(key[b, :valid_len])
+        value_valid.append(value[b, :valid_len])
+
+    query_packed = query.flatten(0, 1)
+    key_packed = torch.cat(key_valid, dim=0)
+    value_packed = torch.cat(value_valid, dim=0)
+
+    out = flash_attn_varlen_func(
+        q=query_packed,
+        k=key_packed,
+        v=value_packed,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        window_size=window_size,
+        softcap=softcap,
+        alibi_slopes=alibi_slopes,
+        deterministic=deterministic,
+        return_attn_probs=return_attn_probs,
+    )
+    out = out.unflatten(0, (batch_size, -1))
+
+    return out
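
The varlen path packs variable-length sequences into one flat token dimension and describes them with cumulative-length tensors. Below is an illustrative sketch of how such `cu_seqlens` values are typically derived from per-sample lengths; it is plain PyTorch and not necessarily the exact logic inside `_prepare_for_flash_attn_or_sage_varlen`.

```python
import torch
import torch.nn.functional as F

# Per-sample valid key/value lengths -> cumulative offsets into the packed tensor.
seqlens_k = torch.tensor([128, 77, 100], dtype=torch.int32)
cu_seqlens_k = F.pad(seqlens_k.cumsum(0), (1, 0)).to(torch.int32)  # tensor([0, 128, 205, 305])
max_seqlen_k = int(seqlens_k.max())
```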
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_3,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_3(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size: Tuple[int, int] = (-1, -1),
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    return_attn_probs: bool = False,
+) -> torch.Tensor:
+    out, lse, *_ = flash_attn_3_func(
+        q=query,
+        k=key,
+        v=value,
+        softmax_scale=scale,
+        causal=is_causal,
+        qv=None,
+        q_descale=None,
+        k_descale=None,
+        v_descale=None,
+        window_size=window_size,
+        attention_chunk=0,
+        softcap=softcap,
+        num_splits=1,
+        pack_gqa=None,
+        deterministic=deterministic,
+        sm_margin=0,
+    )
+    return (out, lse) if return_attn_probs else out
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_VARLEN_3,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_varlen_attention_3(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k: Optional[torch.Tensor] = None,
+    max_seqlen_q: Optional[int] = None,
+    max_seqlen_k: Optional[int] = None,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size: Tuple[int, int] = (-1, -1),
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    return_attn_probs: bool = False,
+    attn_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    batch_size, seq_len_q, _, _ = query.shape
+    _, seq_len_kv, _, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen(
+                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+            )
+        )
+    else:
+        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+    key_valid, value_valid = [], []
+    for b in range(batch_size):
+        valid_len = seqlens_k[b]
+        key_valid.append(key[b, :valid_len])
+        value_valid.append(value[b, :valid_len])
+
+    query_packed = query.flatten(0, 1)
+    key_packed = torch.cat(key_valid, dim=0)
+    value_packed = torch.cat(value_valid, dim=0)
+
+    out, lse, *_ = flash_attn_3_varlen_func(
+        q=query_packed,
+        k=key_packed,
+        v=value_packed,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        seqused_q=None,
+        seqused_k=None,
+        softmax_scale=scale,
+        causal=is_causal,
+        qv=None,
+        q_descale=None,
+        k_descale=None,
+        v_descale=None,
+        window_size=window_size,
+        softcap=softcap,
+        num_splits=1,
+        pack_gqa=None,
+        deterministic=deterministic,
+        sm_margin=0,
+    )
+    out = out.unflatten(0, (batch_size, -1))
+
+    return (out, lse) if return_attn_probs else out
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLEX,
+    constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
+)
+def _native_flex_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    kernel_options: Optional[Dict[str, Any]] = None,
+) -> torch.Tensor:
+    # TODO: should we LRU cache the block mask creation?
+    score_mod = None
+    block_mask = None
+    batch_size, seq_len_q, num_heads, _ = query.shape
+    _, seq_len_kv, _, _ = key.shape
+
+    if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask):
+        block_mask = attn_mask
+    elif is_causal:
+        block_mask = flex_attention.create_block_mask(
+            _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device
+        )
+    elif torch.is_tensor(attn_mask):
+        if attn_mask.ndim == 2:
+            attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+
+        attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv)
+
+        if attn_mask.dtype == torch.bool:
+            # TODO: this probably does not work but verify!
+            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+                return attn_mask[batch_idx, head_idx, q_idx, kv_idx]
+
+            block_mask = flex_attention.create_block_mask(
+                mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device
+            )
+        else:
+
+            def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
+                return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
+    else:
+        raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")
+
+    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    out = flex_attention.flex_attention(
+        query=query,
+        key=key,
+        value=value,
+        score_mod=score_mod,
+        block_mask=block_mask,
+        scale=scale,
+        enable_gqa=enable_gqa,
+        return_lse=return_lse,
+        kernel_options=kernel_options,
+    )
+    out = out.permute(0, 2, 1, 3)
+    return out
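
For context, `torch.nn.attention.flex_attention` (PyTorch >= 2.5) is the API this backend wraps. The minimal causal sketch below uses the same `q_idx >= kv_idx` predicate defined earlier and the `[batch, heads, seq, head_dim]` layout that explains the permutes above; it is written against stock PyTorch on a CUDA device and is not diffusers-specific.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(batch_idx, head_idx, q_idx, kv_idx):
    return q_idx >= kv_idx

q = k = v = torch.randn(1, 4, 256, 64, device="cuda")  # [batch, heads, seq, head_dim]
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=256, KV_LEN=256, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)
```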
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.NATIVE,
+    constraints=[_check_device, _check_shape],
+)
+def _native_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+) -> torch.Tensor:
+    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    out = torch.nn.functional.scaled_dot_product_attention(
+        query=query,
+        key=key,
+        value=value,
+        attn_mask=attn_mask,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        scale=scale,
+        enable_gqa=enable_gqa,
+    )
+    out = out.permute(0, 2, 1, 3)
+    return out
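
The convention throughout these backends is `[batch, seq, heads, head_dim]`, while SDPA expects `[batch, heads, seq, head_dim]`; that layout mismatch is the only reason for the permute in and out. A small self-contained illustration in plain PyTorch:

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 128, 8, 64)  # [batch, seq, heads, head_dim]
k = torch.randn(2, 128, 8, 64)
v = torch.randn(2, 128, 8, 64)
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)
assert out.shape == (2, 128, 8, 64)
```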
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._NATIVE_CUDNN,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _native_cudnn_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+) -> torch.Tensor:
+    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+        out = torch.nn.functional.scaled_dot_product_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+            enable_gqa=enable_gqa,
+        )
+    out = out.permute(0, 2, 1, 3)
+    return out
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._NATIVE_EFFICIENT,
+    constraints=[_check_device, _check_shape],
+)
+def _native_efficient_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+) -> torch.Tensor:
+    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
+        out = torch.nn.functional.scaled_dot_product_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+            enable_gqa=enable_gqa,
+        )
+    out = out.permute(0, 2, 1, 3)
+    return out
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._NATIVE_FLASH,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _native_flash_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+) -> torch.Tensor:
+    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+        out = torch.nn.functional.scaled_dot_product_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=None,  # not supported
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+            enable_gqa=enable_gqa,
+        )
+    out = out.permute(0, 2, 1, 3)
+    return out
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._NATIVE_MATH,
+    constraints=[_check_device, _check_shape],
+)
+def _native_math_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+) -> torch.Tensor:
+    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
+        out = torch.nn.functional.scaled_dot_product_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+            enable_gqa=enable_gqa,
+        )
+    out = out.permute(0, 2, 1, 3)
+    return out
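
These `_NATIVE_*` variants differ only in which SDPA kernel they pin via the `sdpa_kernel` context manager (available since PyTorch 2.3). A minimal standalone sketch of the same pinning, independent of diffusers:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 64, 32)  # [batch, heads, seq, head_dim]
with sdpa_kernel(SDPBackend.MATH):  # force the reference math kernel
    out = F.scaled_dot_product_attention(q, k, v)
```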
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._NATIVE_NPU,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _native_npu_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    scale: Optional[float] = None,
+) -> torch.Tensor:
+    return npu_fusion_attention(
+        query,
+        key,
+        value,
+        query.size(2),  # num_heads
+        input_layout="BSND",
+        pse=None,
+        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+        pre_tockens=65536,
+        next_tockens=65536,
+        keep_prob=1.0 - dropout_p,
+        sync=False,
+        inner_precise=0,
+    )[0]
+
+
+# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._NATIVE_XLA,
+    constraints=[_check_device, _check_shape],
+)
+def _native_xla_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+) -> torch.Tensor:
+    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    query = query / math.sqrt(query.shape[-1])
+    out = xla_flash_attention(
+        q=query,
+        k=key,
+        v=value,
+        causal=is_causal,
+    )
+    out = out.permute(0, 2, 1, 3)
+    return out
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.SAGE,
+    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _sage_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+) -> torch.Tensor:
+    return sageattn(
+        q=query,
+        k=key,
+        v=value,
+        tensor_layout="NHD",
+        is_causal=is_causal,
+        sm_scale=scale,
+        return_lse=return_lse,
+    )
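
A hedged usage sketch for the SageAttention backend above, mirroring only the keyword arguments visible in this diff; it assumes the sageattention package and a CUDA device, and `"NHD"` denotes the `[batch, seq, heads, head_dim]` layout used throughout this file.

```python
import torch
from sageattention import sageattn  # assumes the sageattention package is installed

q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
out = sageattn(q=q, k=k, v=v, tensor_layout="NHD", is_causal=False)
```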
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.SAGE_VARLEN,
+    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _sage_varlen_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k: Optional[torch.Tensor] = None,
+    max_seqlen_q: Optional[int] = None,
+    max_seqlen_k: Optional[int] = None,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    smooth_k: bool = True,
+    attn_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    batch_size, seq_len_q, _, _ = query.shape
+    _, seq_len_kv, _, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen(
+                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+            )
+        )
+    else:
+        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+    key_valid, value_valid = [], []
+    for b in range(batch_size):
+        valid_len = seqlens_k[b]
+        key_valid.append(key[b, :valid_len])
+        value_valid.append(value[b, :valid_len])
+
+    query_packed = query.flatten(0, 1)
+    key_packed = torch.cat(key_valid, dim=0)
+    value_packed = torch.cat(value_valid, dim=0)
+
+    out = sageattn_varlen(
+        q=query_packed,
+        k=key_packed,
+        v=value_packed,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        is_causal=is_causal,
+        sm_scale=scale,
+        smooth_k=smooth_k,
+    )
+    out = out.unflatten(0, (batch_size, -1))
+
+    return out
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
+    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp8_cuda_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
+    smooth_k: bool = True,
+    smooth_v: bool = False,
+    return_lse: bool = False,
+) -> torch.Tensor:
+    return sageattn_qk_int8_pv_fp8_cuda(
+        q=query,
+        k=key,
+        v=value,
+        tensor_layout="NHD",
+        is_causal=is_causal,
+        qk_quant_gran=qk_quant_gran,
+        sm_scale=scale,
+        pv_accum_dtype=pv_accum_dtype,
+        smooth_k=smooth_k,
+        smooth_v=smooth_v,
+        return_lse=return_lse,
+    )
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
+    smooth_k: bool = True,
+    return_lse: bool = False,
+) -> torch.Tensor:
+    return sageattn_qk_int8_pv_fp8_cuda_sm90(
+        q=query,
+        k=key,
+        v=value,
+        tensor_layout="NHD",
+        is_causal=is_causal,
+        qk_quant_gran=qk_quant_gran,
+        sm_scale=scale,
+        pv_accum_dtype=pv_accum_dtype,
+        smooth_k=smooth_k,
+        return_lse=return_lse,
+    )
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
+    constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp16_cuda_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32",
+    smooth_k: bool = True,
+    smooth_v: bool = False,
+    return_lse: bool = False,
+) -> torch.Tensor:
+    return sageattn_qk_int8_pv_fp16_cuda(
+        q=query,
+        k=key,
+        v=value,
+        tensor_layout="NHD",
+        is_causal=is_causal,
+        qk_quant_gran=qk_quant_gran,
+        sm_scale=scale,
+        pv_accum_dtype=pv_accum_dtype,
+        smooth_k=smooth_k,
+        smooth_v=smooth_v,
+        return_lse=return_lse,
+    )
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
+    constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp16_triton_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
+    smooth_k: bool = True,
+    return_lse: bool = False,
+) -> torch.Tensor:
+    return sageattn_qk_int8_pv_fp16_triton(
+        q=query,
+        k=key,
+        v=value,
+        tensor_layout="NHD",
+        quantization_backend=quantization_backend,
+        is_causal=is_causal,
+        sm_scale=scale,
+        smooth_k=smooth_k,
+        return_lse=return_lse,
+    )
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.XFORMERS,
+    constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
+)
+def _xformers_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+) -> torch.Tensor:
+    batch_size, seq_len_q, num_heads_q, _ = query.shape
+    _, seq_len_kv, num_heads_kv, _ = key.shape
+
+    if is_causal:
+        attn_mask = xops.LowerTriangularMask()
+    elif attn_mask is not None:
+        if attn_mask.ndim == 2:
+            attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+        elif attn_mask.ndim != 4:
+            raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
+        attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
+
+    if enable_gqa:
+        if num_heads_q % num_heads_kv != 0:
+            raise ValueError("Number of heads in query must be divisible by number of heads in key/value.")
+        num_heads_per_group = num_heads_q // num_heads_kv
+        query = query.unflatten(2, (num_heads_kv, -1))
+        key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
+        value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
+
+    out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale)
+
+    if enable_gqa:
+        out = out.flatten(2, 3)
+
+    return out