diffusers 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/models/attention_dispatch.py (new file)
@@ -0,0 +1,1218 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import inspect
import math
from enum import Enum
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch

from ..utils import (
    get_logger,
    is_flash_attn_3_available,
    is_flash_attn_available,
    is_flash_attn_version,
    is_sageattention_available,
    is_sageattention_version,
    is_torch_npu_available,
    is_torch_version,
    is_torch_xla_available,
    is_torch_xla_version,
    is_xformers_available,
    is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS


_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
_REQUIRED_XFORMERS_VERSION = "0.0.29"

_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
_CAN_USE_NPU_ATTN = is_torch_npu_available()
_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)


if _CAN_USE_FLASH_ATTN:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
else:
    flash_attn_func = None
    flash_attn_varlen_func = None


if _CAN_USE_FLASH_ATTN_3:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
else:
    flash_attn_3_func = None
    flash_attn_3_varlen_func = None


if _CAN_USE_SAGE_ATTN:
    from sageattention import (
        sageattn,
        sageattn_qk_int8_pv_fp8_cuda,
        sageattn_qk_int8_pv_fp8_cuda_sm90,
        sageattn_qk_int8_pv_fp16_cuda,
        sageattn_qk_int8_pv_fp16_triton,
        sageattn_varlen,
    )
else:
    sageattn = None
    sageattn_qk_int8_pv_fp16_cuda = None
    sageattn_qk_int8_pv_fp16_triton = None
    sageattn_qk_int8_pv_fp8_cuda = None
    sageattn_qk_int8_pv_fp8_cuda_sm90 = None
    sageattn_varlen = None


if _CAN_USE_FLEX_ATTN:
    # We cannot import the flex_attention function from the package directly because it is expected (from the
    # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
    # compiled function.
    import torch.nn.attention.flex_attention as flex_attention


if _CAN_USE_NPU_ATTN:
    from torch_npu import npu_fusion_attention
else:
    npu_fusion_attention = None


if _CAN_USE_XLA_ATTN:
    from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
else:
    xla_flash_attention = None


if _CAN_USE_XFORMERS_ATTN:
    import xformers.ops as xops
else:
    xops = None


logger = get_logger(__name__)  # pylint: disable=invalid-name

# TODO(aryan): Add support for the following:
# - Sage Attention++
# - block sparse, radial and other attention methods
# - CP with sage attention, flex, xformers, other missing backends
# - Add support for normal and CP training with backends that don't support it yet

_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]


class AttentionBackendName(str, Enum):
    # EAGER = "eager"

    # `flash-attn`
    FLASH = "flash"
    FLASH_VARLEN = "flash_varlen"
    _FLASH_3 = "_flash_3"
    _FLASH_VARLEN_3 = "_flash_varlen_3"

    # PyTorch native
    FLEX = "flex"
    NATIVE = "native"
    _NATIVE_CUDNN = "_native_cudnn"
    _NATIVE_EFFICIENT = "_native_efficient"
    _NATIVE_FLASH = "_native_flash"
    _NATIVE_MATH = "_native_math"
    _NATIVE_NPU = "_native_npu"
    _NATIVE_XLA = "_native_xla"

    # `sageattention`
    SAGE = "sage"
    SAGE_VARLEN = "sage_varlen"
    _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
    _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
    _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda"
    _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton"
    # TODO: let's not add support for Sparge Attention now because it requires tuning per model
    # We can look into supporting something "autotune"-ing in the future
    # SPARGE = "sparge"

    # `xformers`
    XFORMERS = "xformers"


class _AttentionBackendRegistry:
    _backends = {}
    _constraints = {}
    _supported_arg_names = {}
    _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
    _checks_enabled = DIFFUSERS_ATTN_CHECKS

    @classmethod
    def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None):
        logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")

        def decorator(func):
            cls._backends[backend] = func
            cls._constraints[backend] = constraints or []
            cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
            return func

        return decorator

    @classmethod
    def get_active_backend(cls):
        return cls._active_backend, cls._backends[cls._active_backend]

    @classmethod
    def list_backends(cls):
        return list(cls._backends.keys())


@contextlib.contextmanager
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
    """
    Context manager to set the active attention backend.
    """
    if backend not in _AttentionBackendRegistry._backends:
        raise ValueError(f"Backend {backend} is not registered.")

    backend = AttentionBackendName(backend)
    _check_attention_backend_requirements(backend)

    old_backend = _AttentionBackendRegistry._active_backend
    _AttentionBackendRegistry._active_backend = backend

    try:
        yield
    finally:
        _AttentionBackendRegistry._active_backend = old_backend


def dispatch_attention_fn(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    *,
    backend: Optional[AttentionBackendName] = None,
) -> torch.Tensor:
    attention_kwargs = attention_kwargs or {}

    if backend is None:
        # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment
        # variable), or we use a custom backend based on whether user is using the `attention_backend` context manager
        backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend()
    else:
        backend_name = AttentionBackendName(backend)
        backend_fn = _AttentionBackendRegistry._backends.get(backend_name)

    kwargs = {
        "query": query,
        "key": key,
        "value": value,
        "attn_mask": attn_mask,
        "dropout_p": dropout_p,
        "is_causal": is_causal,
        "scale": scale,
        **attention_kwargs,
    }
    if is_torch_version(">=", "2.5.0"):
        kwargs["enable_gqa"] = enable_gqa

    if _AttentionBackendRegistry._checks_enabled:
        removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
        if removed_kwargs:
            logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.")
        for check in _AttentionBackendRegistry._constraints.get(backend_name):
            check(**kwargs)

    kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
    return backend_fn(**kwargs)


# ===== Checks =====
# A list of very simple functions to catch common errors quickly when debugging.


def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None:
    if attn_mask is not None and is_causal:
        raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.")


def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    if query.device != key.device or query.device != value.device:
        raise ValueError("Query, key, and value must be on the same device.")
    if query.dtype != key.dtype or query.dtype != value.dtype:
        raise ValueError("Query, key, and value must have the same dtype.")


def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    _check_device(query, key, value)
    if query.device.type != "cuda":
        raise ValueError("Query, key, and value must be on a CUDA device.")


def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable:
    def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
        _check_device_cuda(query, key, value)
        if torch.cuda.get_device_capability(query.device) < (major, minor):
            raise ValueError(
                f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}."
            )

    return check_device_cuda


def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    if query.dtype != key.dtype:
        raise ValueError("Query and key must have the same dtype.")
    if query.dtype != value.dtype:
        raise ValueError("Query and value must have the same dtype.")


def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    _check_qkv_dtype_match(query, key, value)
    if query.dtype not in (torch.bfloat16, torch.float16):
        raise ValueError("Query, key, and value must be either bfloat16 or float16.")


def _check_shape(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> None:
    if query.shape[-1] != key.shape[-1]:
        raise ValueError("Query and key must have the same last dimension.")
    if query.shape[-2] != value.shape[-2]:
        raise ValueError("Query and value must have the same second to last dimension.")
    if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
        raise ValueError("Attention mask must match the key's second to last dimension.")


# ===== Helper functions =====


def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
    if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
        if not _CAN_USE_FLASH_ATTN:
            raise RuntimeError(
                f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
            )

    elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
        if not _CAN_USE_FLASH_ATTN_3:
            raise RuntimeError(
                f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
            )

    elif backend in [
        AttentionBackendName.SAGE,
        AttentionBackendName.SAGE_VARLEN,
        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
        AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
        AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
    ]:
        if not _CAN_USE_SAGE_ATTN:
            raise RuntimeError(
                f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
            )

    elif backend == AttentionBackendName.FLEX:
        if not _CAN_USE_FLEX_ATTN:
            raise RuntimeError(
                f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`."
            )

    elif backend == AttentionBackendName._NATIVE_NPU:
        if not _CAN_USE_NPU_ATTN:
            raise RuntimeError(
                f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`."
            )

    elif backend == AttentionBackendName._NATIVE_XLA:
        if not _CAN_USE_XLA_ATTN:
            raise RuntimeError(
                f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`."
            )

    elif backend == AttentionBackendName.XFORMERS:
        if not _CAN_USE_XFORMERS_ATTN:
            raise RuntimeError(
                f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
            )


@functools.lru_cache(maxsize=128)
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
    batch_size: int,
    seq_len_q: int,
    seq_len_kv: int,
    device: Optional[torch.device] = None,
):
    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
    seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
    max_seqlen_q = seqlens_q.max().item()
    max_seqlen_k = seqlens_k.max().item()
    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)


def _prepare_for_flash_attn_or_sage_varlen_with_mask(
    batch_size: int,
    seq_len_q: int,
    attn_mask: torch.Tensor,
    device: Optional[torch.device] = None,
):
    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
    seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
    max_seqlen_q = seqlens_q.max().item()
    max_seqlen_k = seqlens_k.max().item()
    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)


def _prepare_for_flash_attn_or_sage_varlen(
    batch_size: int,
    seq_len_q: int,
    seq_len_kv: int,
    attn_mask: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
) -> None:
    if attn_mask is None:
        return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)
    return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)


def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
    """
    Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in
    FlashAttention/Sage varlen.

    Supports 1D to 4D shapes and common broadcasting patterns.
    """
    if attn_mask.dtype != torch.bool:
        raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")

    if attn_mask.ndim == 1:
        # [seq_len_k] -> broadcast across batch
        attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)

    elif attn_mask.ndim == 2:
        # [batch_size, seq_len_k]. Maybe broadcast across batch
        if attn_mask.size(0) not in [1, batch_size]:
            raise ValueError(
                f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask."
            )
        attn_mask = attn_mask.expand(batch_size, seq_len_k)

    elif attn_mask.ndim == 3:
        # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension
        # We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen.
        if attn_mask.size(0) not in [1, batch_size]:
            raise ValueError(
                f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask."
            )
        attn_mask = attn_mask.any(dim=1)
        attn_mask = attn_mask.expand(batch_size, seq_len_k)

    elif attn_mask.ndim == 4:
        # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions
        if attn_mask.size(0) not in [1, batch_size]:
            raise ValueError(
                f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask."
            )
        attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k)  # [B, H, Q, K]
        attn_mask = attn_mask.any(dim=(1, 2))  # [B, K]

    else:
        raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}")

    if attn_mask.shape != (batch_size, seq_len_k):
        raise ValueError(
            f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})"
        )

    return attn_mask


def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
    return q_idx >= kv_idx


# ===== torch op registrations =====
# Registrations are required for fullgraph tracing compatibility


# TODO: library.custom_op and register_fake probably need version guards?
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_original(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    out, lse = flash_attn_3_func(query, key, value)
    lse = lse.permute(0, 2, 1)
    return out, lse


@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    batch_size, seq_len, num_heads, head_dim = query.shape
    lse_shape = (batch_size, seq_len, num_heads)
    return torch.empty_like(query), query.new_empty(lse_shape)


# ===== Attention backends =====


@_AttentionBackendRegistry.register(
    AttentionBackendName.FLASH,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    return_attn_probs: bool = False,
) -> torch.Tensor:
    out = flash_attn_func(
        q=query,
        k=key,
        v=value,
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=return_attn_probs,
    )
    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName.FLASH_VARLEN,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    return_attn_probs: bool = False,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
            _prepare_for_flash_attn_or_sage_varlen(
                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
            )
        )
    else:
        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)

    key_valid, value_valid = [], []
    for b in range(batch_size):
        valid_len = seqlens_k[b]
        key_valid.append(key[b, :valid_len])
        value_valid.append(value[b, :valid_len])

    query_packed = query.flatten(0, 1)
    key_packed = torch.cat(key_valid, dim=0)
    value_packed = torch.cat(value_valid, dim=0)

    out = flash_attn_varlen_func(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=return_attn_probs,
    )
    out = out.unflatten(0, (batch_size, -1))

    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName._FLASH_3,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    deterministic: bool = False,
    return_attn_probs: bool = False,
) -> torch.Tensor:
    out, lse, *_ = flash_attn_3_func(
        q=query,
        k=key,
        v=value,
        softmax_scale=scale,
        causal=is_causal,
        qv=None,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        window_size=window_size,
        attention_chunk=0,
        softcap=softcap,
        num_splits=1,
        pack_gqa=None,
        deterministic=deterministic,
        sm_margin=0,
    )
    return (out, lse) if return_attn_probs else out


@_AttentionBackendRegistry.register(
    AttentionBackendName._FLASH_VARLEN_3,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_varlen_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    deterministic: bool = False,
    return_attn_probs: bool = False,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
            _prepare_for_flash_attn_or_sage_varlen(
                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
            )
        )
    else:
        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)

    key_valid, value_valid = [], []
    for b in range(batch_size):
        valid_len = seqlens_k[b]
        key_valid.append(key[b, :valid_len])
        value_valid.append(value[b, :valid_len])

    query_packed = query.flatten(0, 1)
    key_packed = torch.cat(key_valid, dim=0)
    value_packed = torch.cat(value_valid, dim=0)

    out, lse, *_ = flash_attn_3_varlen_func(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        seqused_q=None,
        seqused_k=None,
        softmax_scale=scale,
        causal=is_causal,
        qv=None,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        window_size=window_size,
        softcap=softcap,
        num_splits=1,
        pack_gqa=None,
        deterministic=deterministic,
        sm_margin=0,
    )
    out = out.unflatten(0, (batch_size, -1))

    return (out, lse) if return_attn_probs else out


@_AttentionBackendRegistry.register(
    AttentionBackendName.FLEX,
    constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
)
def _native_flex_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
    # TODO: should we LRU cache the block mask creation?
    score_mod = None
    block_mask = None
    batch_size, seq_len_q, num_heads, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask):
        block_mask = attn_mask
    elif is_causal:
        block_mask = flex_attention.create_block_mask(
            _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device
        )
    elif torch.is_tensor(attn_mask):
        if attn_mask.ndim == 2:
            attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)

        attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv)

        if attn_mask.dtype == torch.bool:
            # TODO: this probably does not work but verify!
            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
                return attn_mask[batch_idx, head_idx, q_idx, kv_idx]

            block_mask = flex_attention.create_block_mask(
                mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device
            )
        else:

            def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
                return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
    else:
        raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")

    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    out = flex_attention.flex_attention(
        query=query,
        key=key,
        value=value,
        score_mod=score_mod,
        block_mask=block_mask,
        scale=scale,
        enable_gqa=enable_gqa,
        return_lse=return_lse,
        kernel_options=kernel_options,
    )
    out = out.permute(0, 2, 1, 3)
    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName.NATIVE,
    constraints=[_check_device, _check_shape],
)
def _native_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    out = torch.nn.functional.scaled_dot_product_attention(
        query=query,
        key=key,
        value=value,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        enable_gqa=enable_gqa,
    )
    out = out.permute(0, 2, 1, 3)
    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName._NATIVE_CUDNN,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _native_cudnn_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
        out = torch.nn.functional.scaled_dot_product_attention(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=enable_gqa,
        )
    out = out.permute(0, 2, 1, 3)
    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName._NATIVE_EFFICIENT,
    constraints=[_check_device, _check_shape],
)
def _native_efficient_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
        out = torch.nn.functional.scaled_dot_product_attention(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=enable_gqa,
        )
    out = out.permute(0, 2, 1, 3)
    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName._NATIVE_FLASH,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _native_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        out = torch.nn.functional.scaled_dot_product_attention(
            query=query,
            key=key,
            value=value,
            attn_mask=None,  # not supported
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=enable_gqa,
        )
    out = out.permute(0, 2, 1, 3)
    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName._NATIVE_MATH,
    constraints=[_check_device, _check_shape],
)
def _native_math_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
        out = torch.nn.functional.scaled_dot_product_attention(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=enable_gqa,
        )
    out = out.permute(0, 2, 1, 3)
    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName._NATIVE_NPU,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _native_npu_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
) -> torch.Tensor:
    return npu_fusion_attention(
        query,
        key,
        value,
        query.size(2),  # num_heads
        input_layout="BSND",
        pse=None,
        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
        pre_tockens=65536,
        next_tockens=65536,
        keep_prob=1.0 - dropout_p,
        sync=False,
        inner_precise=0,
    )[0]


# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
@_AttentionBackendRegistry.register(
    AttentionBackendName._NATIVE_XLA,
    constraints=[_check_device, _check_shape],
)
def _native_xla_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
) -> torch.Tensor:
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    query = query / math.sqrt(query.shape[-1])
    out = xla_flash_attention(
        q=query,
        k=key,
        v=value,
        causal=is_causal,
    )
    out = out.permute(0, 2, 1, 3)
    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName.SAGE,
    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _sage_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    return sageattn(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        is_causal=is_causal,
        sm_scale=scale,
        return_lse=return_lse,
    )


@_AttentionBackendRegistry.register(
    AttentionBackendName.SAGE_VARLEN,
    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _sage_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
    smooth_k: bool = True,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
            _prepare_for_flash_attn_or_sage_varlen(
                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
            )
        )
    else:
        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)

    key_valid, value_valid = [], []
    for b in range(batch_size):
        valid_len = seqlens_k[b]
        key_valid.append(key[b, :valid_len])
        value_valid.append(value[b, :valid_len])

    query_packed = query.flatten(0, 1)
    key_packed = torch.cat(key_valid, dim=0)
    value_packed = torch.cat(value_valid, dim=0)

    out = sageattn_varlen(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        is_causal=is_causal,
        sm_scale=scale,
        smooth_k=smooth_k,
    )
    out = out.unflatten(0, (batch_size, -1))

    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
)
def _sage_qk_int8_pv_fp8_cuda_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
    smooth_k: bool = True,
    smooth_v: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    return sageattn_qk_int8_pv_fp8_cuda(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        is_causal=is_causal,
        qk_quant_gran=qk_quant_gran,
        sm_scale=scale,
        pv_accum_dtype=pv_accum_dtype,
        smooth_k=smooth_k,
        smooth_v=smooth_v,
        return_lse=return_lse,
    )


@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
)
def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
    smooth_k: bool = True,
    return_lse: bool = False,
) -> torch.Tensor:
    return sageattn_qk_int8_pv_fp8_cuda_sm90(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        is_causal=is_causal,
        qk_quant_gran=qk_quant_gran,
        sm_scale=scale,
        pv_accum_dtype=pv_accum_dtype,
        smooth_k=smooth_k,
        return_lse=return_lse,
    )


@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
    constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
)
def _sage_qk_int8_pv_fp16_cuda_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32",
    smooth_k: bool = True,
    smooth_v: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    return sageattn_qk_int8_pv_fp16_cuda(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        is_causal=is_causal,
        qk_quant_gran=qk_quant_gran,
        sm_scale=scale,
        pv_accum_dtype=pv_accum_dtype,
        smooth_k=smooth_k,
        smooth_v=smooth_v,
        return_lse=return_lse,
    )


@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
    constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
)
def _sage_qk_int8_pv_fp16_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
    smooth_k: bool = True,
    return_lse: bool = False,
) -> torch.Tensor:
    return sageattn_qk_int8_pv_fp16_triton(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        quantization_backend=quantization_backend,
        is_causal=is_causal,
        sm_scale=scale,
        smooth_k=smooth_k,
        return_lse=return_lse,
    )


@_AttentionBackendRegistry.register(
    AttentionBackendName.XFORMERS,
    constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
)
def _xformers_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    batch_size, seq_len_q, num_heads_q, _ = query.shape
    _, seq_len_kv, num_heads_kv, _ = key.shape

    if is_causal:
        attn_mask = xops.LowerTriangularMask()
    elif attn_mask is not None:
        if attn_mask.ndim == 2:
            attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
        elif attn_mask.ndim != 4:
            raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
        attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)

    if enable_gqa:
        if num_heads_q % num_heads_kv != 0:
            raise ValueError("Number of heads in query must be divisible by number of heads in key/value.")
        num_heads_per_group = num_heads_q // num_heads_kv
        query = query.unflatten(2, (num_heads_kv, -1))
        key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
        value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)

    out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale)

    if enable_gqa:
        out = out.flatten(2, 3)

    return out
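For orientation, here is a minimal usage sketch of the dispatcher added above; it is not an official example from this release. The import path mirrors the module location shown in this diff, and the [batch, seq_len, heads, head_dim] tensor layout is inferred from the code (the NATIVE backend permutes to [batch, heads, seq_len, head_dim] before calling torch's scaled_dot_product_attention); treat both as assumptions.

import torch

# Assumed import path, taken from the module location in this diff.
from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    attention_backend,
    dispatch_attention_fn,
)

# Tensors in [batch, seq_len, heads, head_dim] layout (inferred from the backends above).
query = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
key = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
value = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)

# Uses the active backend: the DIFFUSERS_ATTN_BACKEND setting, which the code above
# suggests falls back to the native SDPA path.
out = dispatch_attention_fn(query, key, value)

# Temporarily switch backends; the previous backend is restored when the context exits.
# FLASH requires flash-attn >= 2.6.3 to be installed.
with attention_backend(AttentionBackendName.FLASH):
    out = dispatch_attention_fn(query, key, value)

print(out.shape)  # torch.Size([2, 128, 8, 64])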