vllm-npu 0.4.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- vllm/__init__.py +23 -0
- vllm/_custom_ops.py +251 -0
- vllm/attention/__init__.py +13 -0
- vllm/attention/backends/__init__.py +0 -0
- vllm/attention/backends/abstract.py +127 -0
- vllm/attention/backends/flash_attn.py +271 -0
- vllm/attention/backends/flashinfer.py +220 -0
- vllm/attention/backends/rocm_flash_attn.py +374 -0
- vllm/attention/backends/torch_sdpa.py +250 -0
- vllm/attention/backends/xformers.py +393 -0
- vllm/attention/layer.py +56 -0
- vllm/attention/ops/__init__.py +0 -0
- vllm/attention/ops/paged_attn.py +216 -0
- vllm/attention/ops/prefix_prefill.py +792 -0
- vllm/attention/ops/triton_flash_attention.py +810 -0
- vllm/attention/selector.py +91 -0
- vllm/block.py +84 -0
- vllm/config.py +1225 -0
- vllm/core/__init__.py +0 -0
- vllm/core/block/__init__.py +0 -0
- vllm/core/block/block_table.py +295 -0
- vllm/core/block/common.py +199 -0
- vllm/core/block/cpu_gpu_block_allocator.py +228 -0
- vllm/core/block/interfaces.py +205 -0
- vllm/core/block/naive_block.py +318 -0
- vllm/core/block/prefix_caching_block.py +606 -0
- vllm/core/block_manager_v1.py +625 -0
- vllm/core/block_manager_v2.py +258 -0
- vllm/core/evictor_v1.py +105 -0
- vllm/core/evictor_v2.py +127 -0
- vllm/core/interfaces.py +113 -0
- vllm/core/policy.py +45 -0
- vllm/core/scheduler.py +1163 -0
- vllm/distributed/__init__.py +3 -0
- vllm/distributed/communication_op.py +237 -0
- vllm/distributed/device_communicators/__init__.py +0 -0
- vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
- vllm/distributed/device_communicators/pynccl.py +287 -0
- vllm/distributed/device_communicators/pynccl_utils.py +66 -0
- vllm/distributed/parallel_state.py +339 -0
- vllm/distributed/utils.py +136 -0
- vllm/engine/__init__.py +0 -0
- vllm/engine/arg_utils.py +649 -0
- vllm/engine/async_llm_engine.py +737 -0
- vllm/engine/llm_engine.py +784 -0
- vllm/engine/metrics.py +368 -0
- vllm/engine/output_processor/__init__.py +0 -0
- vllm/engine/output_processor/interfaces.py +76 -0
- vllm/engine/output_processor/multi_step.py +142 -0
- vllm/engine/output_processor/single_step.py +284 -0
- vllm/engine/output_processor/stop_checker.py +101 -0
- vllm/engine/output_processor/util.py +19 -0
- vllm/entrypoints/__init__.py +0 -0
- vllm/entrypoints/api_server.py +119 -0
- vllm/entrypoints/llm.py +259 -0
- vllm/entrypoints/openai/__init__.py +0 -0
- vllm/entrypoints/openai/api_server.py +186 -0
- vllm/entrypoints/openai/cli_args.py +115 -0
- vllm/entrypoints/openai/protocol.py +460 -0
- vllm/entrypoints/openai/serving_chat.py +392 -0
- vllm/entrypoints/openai/serving_completion.py +347 -0
- vllm/entrypoints/openai/serving_engine.py +234 -0
- vllm/envs.py +217 -0
- vllm/executor/__init__.py +0 -0
- vllm/executor/cpu_executor.py +152 -0
- vllm/executor/distributed_gpu_executor.py +115 -0
- vllm/executor/executor_base.py +115 -0
- vllm/executor/gpu_executor.py +150 -0
- vllm/executor/multiproc_worker_utils.py +263 -0
- vllm/executor/neuron_executor.py +91 -0
- vllm/executor/ray_gpu_executor.py +327 -0
- vllm/executor/ray_utils.py +119 -0
- vllm/logger.py +153 -0
- vllm/logging/__init__.py +5 -0
- vllm/logging/formatter.py +15 -0
- vllm/lora/__init__.py +0 -0
- vllm/lora/fully_sharded_layers.py +262 -0
- vllm/lora/layers.py +1181 -0
- vllm/lora/lora.py +167 -0
- vllm/lora/models.py +645 -0
- vllm/lora/punica.py +213 -0
- vllm/lora/request.py +32 -0
- vllm/lora/utils.py +98 -0
- vllm/lora/worker_manager.py +251 -0
- vllm/model_executor/__init__.py +7 -0
- vllm/model_executor/guided_decoding/__init__.py +25 -0
- vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
- vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
- vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
- vllm/model_executor/layers/__init__.py +0 -0
- vllm/model_executor/layers/activation.py +173 -0
- vllm/model_executor/layers/fused_moe/__init__.py +7 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
- vllm/model_executor/layers/layernorm.py +71 -0
- vllm/model_executor/layers/linear.py +709 -0
- vllm/model_executor/layers/logits_processor.py +115 -0
- vllm/model_executor/layers/ops/__init__.py +0 -0
- vllm/model_executor/layers/ops/rand.py +157 -0
- vllm/model_executor/layers/ops/sample.py +406 -0
- vllm/model_executor/layers/quantization/__init__.py +35 -0
- vllm/model_executor/layers/quantization/aqlm.py +376 -0
- vllm/model_executor/layers/quantization/awq.py +175 -0
- vllm/model_executor/layers/quantization/base_config.py +97 -0
- vllm/model_executor/layers/quantization/fp8.py +265 -0
- vllm/model_executor/layers/quantization/gptq.py +224 -0
- vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
- vllm/model_executor/layers/quantization/marlin.py +227 -0
- vllm/model_executor/layers/quantization/schema.py +84 -0
- vllm/model_executor/layers/quantization/squeezellm.py +137 -0
- vllm/model_executor/layers/rejection_sampler.py +405 -0
- vllm/model_executor/layers/rotary_embedding.py +525 -0
- vllm/model_executor/layers/sampler.py +1051 -0
- vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
- vllm/model_executor/model_loader/__init__.py +30 -0
- vllm/model_executor/model_loader/loader.py +362 -0
- vllm/model_executor/model_loader/neuron.py +136 -0
- vllm/model_executor/model_loader/tensorizer.py +368 -0
- vllm/model_executor/model_loader/utils.py +41 -0
- vllm/model_executor/model_loader/weight_utils.py +372 -0
- vllm/model_executor/models/__init__.py +119 -0
- vllm/model_executor/models/baichuan.py +410 -0
- vllm/model_executor/models/bloom.py +327 -0
- vllm/model_executor/models/chatglm.py +386 -0
- vllm/model_executor/models/commandr.py +373 -0
- vllm/model_executor/models/dbrx.py +413 -0
- vllm/model_executor/models/decilm.py +122 -0
- vllm/model_executor/models/deepseek.py +438 -0
- vllm/model_executor/models/falcon.py +444 -0
- vllm/model_executor/models/gemma.py +393 -0
- vllm/model_executor/models/gpt2.py +266 -0
- vllm/model_executor/models/gpt_bigcode.py +274 -0
- vllm/model_executor/models/gpt_j.py +281 -0
- vllm/model_executor/models/gpt_neox.py +295 -0
- vllm/model_executor/models/internlm2.py +323 -0
- vllm/model_executor/models/jais.py +333 -0
- vllm/model_executor/models/llama.py +442 -0
- vllm/model_executor/models/llava.py +239 -0
- vllm/model_executor/models/minicpm.py +531 -0
- vllm/model_executor/models/mixtral.py +583 -0
- vllm/model_executor/models/mixtral_quant.py +404 -0
- vllm/model_executor/models/mpt.py +295 -0
- vllm/model_executor/models/olmo.py +356 -0
- vllm/model_executor/models/opt.py +349 -0
- vllm/model_executor/models/orion.py +319 -0
- vllm/model_executor/models/phi.py +300 -0
- vllm/model_executor/models/qwen.py +284 -0
- vllm/model_executor/models/qwen2.py +367 -0
- vllm/model_executor/models/qwen2_moe.py +447 -0
- vllm/model_executor/models/stablelm.py +301 -0
- vllm/model_executor/models/starcoder2.py +302 -0
- vllm/model_executor/models/xverse.py +366 -0
- vllm/model_executor/sampling_metadata.py +588 -0
- vllm/model_executor/utils.py +35 -0
- vllm/outputs.py +150 -0
- vllm/py.typed +2 -0
- vllm/sampling_params.py +340 -0
- vllm/sequence.py +766 -0
- vllm/spec_decode/__init__.py +0 -0
- vllm/spec_decode/batch_expansion.py +397 -0
- vllm/spec_decode/interfaces.py +73 -0
- vllm/spec_decode/metrics.py +191 -0
- vllm/spec_decode/multi_step_worker.py +203 -0
- vllm/spec_decode/ngram_worker.py +176 -0
- vllm/spec_decode/spec_decode_worker.py +472 -0
- vllm/spec_decode/top1_proposer.py +200 -0
- vllm/spec_decode/util.py +228 -0
- vllm/test_utils.py +41 -0
- vllm/transformers_utils/__init__.py +0 -0
- vllm/transformers_utils/config.py +58 -0
- vllm/transformers_utils/configs/__init__.py +16 -0
- vllm/transformers_utils/configs/chatglm.py +68 -0
- vllm/transformers_utils/configs/dbrx.py +278 -0
- vllm/transformers_utils/configs/falcon.py +87 -0
- vllm/transformers_utils/configs/jais.py +236 -0
- vllm/transformers_utils/configs/mpt.py +178 -0
- vllm/transformers_utils/detokenizer.py +313 -0
- vllm/transformers_utils/tokenizer.py +149 -0
- vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
- vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
- vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
- vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
- vllm/transformers_utils/tokenizers/__init__.py +5 -0
- vllm/transformers_utils/tokenizers/baichuan.py +255 -0
- vllm/usage/__init__.py +0 -0
- vllm/usage/usage_lib.py +209 -0
- vllm/utils.py +677 -0
- vllm/worker/__init__.py +0 -0
- vllm/worker/cache_engine.py +105 -0
- vllm/worker/cpu_model_runner.py +346 -0
- vllm/worker/cpu_worker.py +321 -0
- vllm/worker/model_runner.py +1168 -0
- vllm/worker/neuron_model_runner.py +196 -0
- vllm/worker/neuron_worker.py +98 -0
- vllm/worker/worker.py +345 -0
- vllm/worker/worker_base.py +146 -0
- vllm_npu-0.4.2.dist-info/LICENSE +201 -0
- vllm_npu-0.4.2.dist-info/METADATA +173 -0
- vllm_npu-0.4.2.dist-info/RECORD +219 -0
- vllm_npu-0.4.2.dist-info/WHEEL +5 -0
- vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,810 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
"""
|
3
|
+
Fused Attention
|
4
|
+
===============
|
5
|
+
|
6
|
+
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
|
7
|
+
(https://tridao.me/publications/flash2/flash2.pdf)
|
8
|
+
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
|
9
|
+
|
10
|
+
Features supported:
|
11
|
+
|
12
|
+
1) Fwd with causal masking
|
13
|
+
2) Any sequence lengths without padding (currently fwd kernel only)
|
14
|
+
3) Support for different sequence lengths for q and k
|
15
|
+
4) Nested tensor API currently does not support dropout or bias.
|
16
|
+
|
17
|
+
Not currently supported:
|
18
|
+
|
19
|
+
1) Non power of two head dims
|
20
|
+
|
21
|
+
"""
|
22
|
+
|
23
|
+
import torch
|
24
|
+
import triton
|
25
|
+
import triton.language as tl
|
26
|
+
|
27
|
+
# NOTE(review): module-level constexpr dtype alias. None of the kernels visible
# in this section reference it — confirm nothing else in the file uses it
# before removing.
torch_dtype: tl.constexpr = torch.float16
|
28
|
+
|
29
|
+
|
30
|
+
@triton.jit
def cdiv_fn(x, y):
    """Integer ceiling division.

    Computes ``(x + y - 1) // y``, which equals ceil(x / y) for non-negative
    ``x`` and positive ``y``. Callers in this file use it to count BLOCK_N
    tiles along the key dimension.
    """
    rounded_up = x + y - 1
    return rounded_up // y
|
33
|
+
|
34
|
+
|
35
|
+
@triton.jit
def max_fn(x, y):
    """Element-wise maximum of ``x`` and ``y``.

    Uses the core ``tl.maximum`` op (the same one ``_attn_fwd_inner`` uses for
    the running row max) instead of the ``tl.math.max`` alias, which is not
    available in some newer Triton releases.
    """
    return tl.maximum(x, y)
|
38
|
+
|
39
|
+
|
40
|
+
@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
    """Build an (m, n) tile of Philox counter offsets.

    Element (i, j) of the result is ``philox_offset + i * stride + j``.
    ``philox_seed`` and ``dropout_p`` are not used here (presumably kept so
    all dropout helpers share one signature).
    """
    row_part = tl.arange(0, m)[:, None] * stride
    col_part = tl.arange(0, n)[None, :]
    return philox_offset + row_part + col_part
|
45
|
+
|
46
|
+
|
47
|
+
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
    """Draw an (m, n) tile of uniform random values with ``tl.rand``.

    The counters come from ``dropout_offsets`` and are cast to uint32
    before being fed to the Philox generator seeded with ``philox_seed``.
    """
    offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
                              stride)
    # TODO: use tl.randint for better performance
    return tl.rand(philox_seed, offsets.to(tl.uint32))
|
53
|
+
|
54
|
+
|
55
|
+
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
    """Return an (m, n) boolean keep-mask.

    An element is kept (True) when its random draw is strictly greater than
    ``dropout_p``.
    """
    return dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
                       stride) > dropout_p
|
61
|
+
|
62
|
+
|
63
|
+
@triton.jit
def load_fn(block_ptr, first, second, pad):
    """Load the tile behind ``block_ptr`` with optional bounds checking.

    Args:
        block_ptr: block pointer to load from.
        first: if truthy, bounds-check dimension 0 of the block pointer.
        second: if truthy, bounds-check dimension 1 of the block pointer.
        pad: ``padding_option`` for out-of-range elements (callers in this
            file pass ``"zero"``).

    When neither flag is set the load is issued without any bounds check.
    """
    if first:
        if second:
            tensor = tl.load(block_ptr,
                             boundary_check=(0, 1),
                             padding_option=pad)
        else:
            tensor = tl.load(block_ptr,
                             boundary_check=(0, ),
                             padding_option=pad)
    elif second:
        tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
    else:
        tensor = tl.load(block_ptr)
    return tensor
|
74
|
+
|
75
|
+
|
76
|
+
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    q,
    K_block_ptr,
    V_block_ptr,
    start_m,
    actual_seqlen_k,
    dropout_p,
    philox_seed,
    batch_philox_offset,
    encoded_softmax_block_ptr,
    block_min,
    block_max,
    offs_n_causal,
    masked_blocks,
    n_extra_tokens,
    bias_ptr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    OFFS_M: tl.constexpr,
    OFFS_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    MASK_STEPS: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_ENCODED_SOFTMAX: tl.constexpr,
    PADDED_HEAD: tl.constexpr,
):
    """Online-softmax loop over key/value tiles [block_min, block_max).

    For each BLOCK_N-wide key tile this computes ``qk = q @ k`` (``q`` is
    expected to be pre-scaled by ``sm_scale * log2(e)`` by the caller), applies
    optional key-padding / causal masks and bias, and updates the running
    statistics in the base-2 domain:

    * ``m_i``: running row maximum of ``qk``,
    * ``l_i``: running row sum of ``exp2(qk - m_i)``,
    * ``acc``: running un-normalised output ``sum(p @ v)``.

    ``MASK_STEPS`` enables key-padding masking on the tile that ends at
    ``block_max`` when ``n_extra_tokens != 0``; ``IS_CAUSAL`` masks entries
    with ``OFFS_M < start_n + offs_n_causal``. Every block pointer passed in
    is advanced by one tile per iteration.

    Returns:
        The updated ``(acc, l_i, m_i)``.
    """
    # loop over k, v, and update accumulator
    for start_n in range(block_min, block_max, BLOCK_N):
        # For padded blocks, we will overrun the tensor size if
        # we load all BLOCK_N. For others, the blocks are all within range.
        k = load_fn(
            K_block_ptr,
            PADDED_HEAD,
            MASK_STEPS and (n_extra_tokens != 0),
            "zero",
        )
        if PRE_LOAD_V:
            v = load_fn(
                V_block_ptr,
                MASK_STEPS and (n_extra_tokens != 0),
                PADDED_HEAD,
                "zero",
            )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # We start from end of seqlen_k so only the first iteration would need
        # to be checked for padding if it is not a multiple of block_n
        # TODO: This can be optimized to only be true for the padded block.
        if MASK_STEPS:  # noqa: SIM102
            # If this is the last block / iteration, we want to
            # mask if the sequence length is not a multiple of block size
            # a solution is to always do BLOCK_M // BLOCK_N + 1 steps
            # if not is_modulo_mn. last step might get wasted but that is okay.
            # check if this masking works for that case.
            if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
                boundary_m = tl.full([BLOCK_M],
                                     actual_seqlen_k,
                                     dtype=tl.int32)
                size_n = start_n + OFFS_N[None, :]
                mask = size_n < boundary_m[:, None]
                qk = tl.where(mask, qk, float("-inf"))
        if IS_CAUSAL:
            causal_boundary = start_n + offs_n_causal
            causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
            qk = tl.where(causal_mask, qk, float("-inf"))
        # -- compute qk ----
        qk += tl.dot(q, k)
        if bias_ptr is not None:
            # Bounds-check the query rows (dim 0) too: the last BLOCK_M tile
            # can extend past seqlen_q, and an unchecked load would read bias
            # memory outside the (seqlen_q, seqlen_k) view. Those padded rows
            # are never stored, so zero padding is harmless.
            bias = load_fn(bias_ptr, True, MASK_STEPS
                           and (n_extra_tokens != 0), "zero")
            # While bias is added after multiplying qk with sm_scale, our
            # optimization to use 2^x instead of e^x results in an additional
            # scale factor of log2(e) which we must also multiply the bias with.
            qk += bias * 1.44269504089
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)

        # CAVEAT: Must update l_ij before applying dropout
        l_ij = tl.sum(p, 1)
        if ENABLE_DROPOUT:
            # NOTE(review): the trailing "- BLOCK_N" shifts the Philox stream
            # by one tile; kept as-is to preserve the random sequence.
            philox_offset = (batch_philox_offset +
                             start_m * BLOCK_M * actual_seqlen_k + start_n -
                             BLOCK_N)
            keep = dropout_mask(
                philox_seed,
                philox_offset,
                dropout_p,
                BLOCK_M,
                BLOCK_N,
                actual_seqlen_k,
            )
            if RETURN_ENCODED_SOFTMAX:
                # Dropped entries are encoded as negative probabilities.
                tl.store(
                    encoded_softmax_block_ptr,
                    tl.where(keep, p,
                             -p).to(encoded_softmax_block_ptr.type.element_ty),
                )
            p = tl.where(keep, p, 0.0)
        elif RETURN_ENCODED_SOFTMAX:
            tl.store(
                encoded_softmax_block_ptr,
                p.to(encoded_softmax_block_ptr.type.element_ty),
            )
        # -- rescale the accumulator to the new running max --
        alpha = tl.math.exp2(m_i - m_ij)
        acc = acc * alpha[:, None]
        if not PRE_LOAD_V:
            v = load_fn(
                V_block_ptr,
                MASK_STEPS and (n_extra_tokens != 0),
                PADDED_HEAD,
                "zero",
            )
        # -- update m_i and l_i --
        l_i = l_i * alpha + l_ij
        m_i = m_ij
        # -- accumulate p @ v --
        acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
        # -- move every tile pointer to the next key tile --
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
                                                   (0, BLOCK_N))
    return acc, l_i, m_i
|
208
|
+
|
209
|
+
|
210
|
+
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "waves_per_eu": 2,
             "PRE_LOAD_V": False},
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2,
             "PRE_LOAD_V": False},
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 128, "waves_per_eu": 2,
             "PRE_LOAD_V": False},
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3,
             "PRE_LOAD_V": True},
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3,
             "PRE_LOAD_V": False},
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 4,
             "PRE_LOAD_V": False},
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 4,
             "PRE_LOAD_V": False},
            num_stages=1,
            num_warps=8,
        ),
        # TODO: This config fails with head_size not pow2 with data mismatches.
        # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
        #                'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
        triton.Config(
            {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 1,
             "PRE_LOAD_V": False},
            num_stages=1,
            num_warps=4,
        ),
    ],
    key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
)
@triton.jit
def attn_fwd(
    Q,
    K,
    V,
    bias,
    sm_scale,
    L,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    stride_bz, stride_bh, stride_bm, stride_bn,
    cu_seqlens_q,
    cu_seqlens_k,
    dropout_p,
    philox_seed,
    philox_offset_base,
    encoded_softmax,
    HQ: tl.constexpr,
    HK: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr,
    MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BIAS_TYPE: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_ENCODED_SOFTMAX: tl.constexpr,
):
    """Flash-attention v2 forward kernel.

    Program ids: axis 0 selects a BLOCK_M tile of query rows (``start_m``),
    axis 1 the query head (``off_h_q``), axis 2 the batch entry (``off_z``).
    With ``VARLEN`` the per-entry query/key ranges are read from
    ``cu_seqlens_q`` / ``cu_seqlens_k``; otherwise ``MAX_SEQLENS_Q`` /
    ``MAX_SEQLENS_K`` are used. K/V heads are shared by groups of
    ``HQ // HK`` query heads (MQA / GQA). ``BLOCK_DMODEL`` is the padded head
    size and ``ACTUAL_BLOCK_DMODEL`` the real one; padded columns are
    zero-loaded and never stored.

    The key range is processed in two phases: unmasked full tiles first, then
    the tiles that need causal and/or key-padding masks. The normalised
    result is written to ``Out``. LSE (``L``) write-back is currently
    disabled (see the commented-out code below).
    """
    start_m = tl.program_id(0)
    off_h_q = tl.program_id(1)
    off_z = tl.program_id(2)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    if VARLEN:
        cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
        cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
        seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
        # We have a one-size-fits-all grid in id(0). Some seqlens might be too
        # small for all start_m so for those we return early. A tile starting
        # exactly at seqlen_q contains no valid rows either, hence ">=".
        if start_m * BLOCK_M >= seqlen_q:
            return
        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
    else:
        cu_seqlens_q_start = 0
        cu_seqlens_k_start = 0
        seqlen_q = MAX_SEQLENS_Q
        seqlen_k = MAX_SEQLENS_K

    # Now we compute whether we need to exit early due to causal masking.
    # This is because for seqlen_q > seqlen_k, M rows of the attn scores
    # are completely masked, resulting in 0s written to the output, and
    # inf written to LSE. We don't need to do any GEMMs in this case.
    # This block of code determines what N is, and if this WG is operating
    # on those M rows.
    n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
    if IS_CAUSAL:
        # If seqlen_q == seqlen_k, the attn scores are a square matrix.
        # If seqlen_q != seqlen_k, attn scores are rectangular which means
        # the causal mask boundary is bottom right aligned, and ends at either
        # the top edge (seqlen_q < seqlen_k) or left edge.
        # This captures the decrease in n_blocks if we have a rectangular attn
        # matrix
        n_blocks_seqlen = cdiv_fn(
            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
        # This is what adjusts the block_max for the current WG, only
        # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
        n_blocks = min(n_blocks, n_blocks_seqlen)
        # If we have no blocks after adjusting for seqlen deltas, this WG is
        # part of the blocks that are all 0. We exit early.
        if n_blocks <= 0:
            o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
                        off_h_q * stride_oh)
            # Use the real head size so padded columns are never written.
            O_block_ptr = tl.make_block_ptr(
                base=Out + o_offset,
                shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
                strides=(stride_om, stride_on),
                offsets=(start_m * BLOCK_M, 0),
                block_shape=(BLOCK_M, BLOCK_DMODEL),
                order=(1, 0),
            )
            acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
            # We still need to write 0s to the result: the output buffer is
            # not guaranteed to be zero-initialised, and fully masked rows in
            # partially masked tiles are zeroed in the epilogue as well.
            tl.store(O_block_ptr, acc, boundary_check=(0, 1))
            # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
            # + offs_m
            # We store inf to LSE, not -inf because in the bwd pass,
            # we subtract this
            # from qk which makes it -inf, such that exp(qk - inf) = 0
            # for these masked blocks.
            # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
            # tl.store(l_ptrs, l)
            # TODO: Should dropout and return encoded softmax be handled here?
            return

    # If MQA / GQA, set the K and V head offsets appropriately.
    GROUP_SIZE: tl.constexpr = HQ // HK
    off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q

    # Only whether n_extra_tokens is non-zero matters downstream: it flags a
    # key tile that runs past seqlen_k.
    n_extra_tokens = 0
    if seqlen_k < BLOCK_N:
        n_extra_tokens = BLOCK_N - seqlen_k
    elif seqlen_k % BLOCK_N:
        n_extra_tokens = seqlen_k % BLOCK_N
    padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL

    # Compute pointers for all the tensors used in this kernel.
    q_offset = (off_z * stride_qz + off_h_q * stride_qh +
                cu_seqlens_q_start * stride_qm)
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    k_offset = (off_z * stride_kz + off_h_k * stride_kh +
                cu_seqlens_k_start * stride_kn)
    # K is viewed transposed (head_dim x seqlen_k) so that q @ k is a plain dot.
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    v_offset = (off_z * stride_vz + off_h_k * stride_vh +
                cu_seqlens_k_start * stride_vk)
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    if BIAS_TYPE != 0:
        # NOTE(review): the base offset uses only the head stride; stride_bz
        # is not applied, so the bias is effectively shared across the batch.
        bias_ptr = tl.make_block_ptr(
            base=bias + off_h_q * stride_bh,
            shape=(seqlen_q, seqlen_k),
            strides=(stride_bm, stride_bn),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
    else:
        bias_ptr = None
    if ENABLE_DROPOUT:
        batch_philox_offset = philox_offset_base \
                              + (off_z * HQ + off_h_q) \
                              * seqlen_q * seqlen_k
    else:
        batch_philox_offset = 0
    # We can ask to return the dropout mask without actually doing any dropout.
    # In this case, we return an invalid pointer to indicate the mask is not
    # valid.
    # TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
    if RETURN_ENCODED_SOFTMAX:
        encoded_softmax_block_ptr = tl.make_block_ptr(
            base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
            shape=(seqlen_q, seqlen_k),
            strides=(seqlen_k, 1),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
    else:
        encoded_softmax_block_ptr = 0
    # Initialise the running row max (m_i), row sum (l_i) and output (acc).
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use 2^x in the loop as we do not
    # have native e^x support in HW.
    qk_scale = sm_scale * 1.44269504089
    # Q is loaded once at the beginning and shared by all N blocks.
    q = load_fn(Q_block_ptr, True, padded_head, "zero")
    q = (q * qk_scale).to(Q_block_ptr.type.element_ty)

    # Here we compute how many full and masked blocks we have.
    padded_block_k = n_extra_tokens != 0
    is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
    if IS_CAUSAL:
        # There are always at least BLOCK_M // BLOCK_N masked blocks.
        # Additionally there might be one more due to dissimilar seqlens.
        masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
    else:
        # Padding on Q does not need to be masked in the FA loop.
        masked_blocks = padded_block_k
    # if IS_CAUSAL, not is_modulo_mn does not always result in an additional
    # block. In this case we might exceed n_blocks so pick the min.
    masked_blocks = min(masked_blocks, n_blocks)
    n_full_blocks = n_blocks - masked_blocks
    block_min = 0
    block_max = n_blocks * BLOCK_N
    # Compute for full blocks. Here we set causal to false regardless of its
    # value because there is no masking. Similarly we do not need padding.
    if n_full_blocks > 0:
        block_max = (n_blocks - masked_blocks) * BLOCK_N
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            encoded_softmax_block_ptr,
            # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
            block_min,
            block_max,
            0,
            0,
            0,
            bias_ptr,
            # IS_CAUSAL, ....
            False,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            offs_m,
            offs_n,
            # _, MASK_STEPS, ...
            PRE_LOAD_V,
            False,
            ENABLE_DROPOUT,
            RETURN_ENCODED_SOFTMAX,
            padded_head,
        )
        block_min = block_max
        block_max = n_blocks * BLOCK_N

    tl.debug_barrier()
    # Remaining blocks, if any, need masking (causal and/or key padding).
    if masked_blocks > 0:
        offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
        # Skip past the full tiles processed above. All tile pointers advance
        # by the same number of key columns (n_full_blocks * BLOCK_N).
        K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
                                                   (0, n_full_blocks * BLOCK_N))
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            encoded_softmax_block_ptr,
            block_min,
            block_max,
            offs_n_causal,
            masked_blocks,
            n_extra_tokens,
            bias_ptr,
            IS_CAUSAL,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            offs_m,
            offs_n,
            # _, MASK_STEPS, ...
            PRE_LOAD_V,
            True,
            ENABLE_DROPOUT,
            RETURN_ENCODED_SOFTMAX,
            padded_head,
        )
    # epilogue
    acc = acc / l_i[:, None]
    if ENABLE_DROPOUT:
        acc = acc / (1 - dropout_p)
    # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
    # then we have one block with a row of all NaNs which come from computing
    # softmax over a row of all -infs (-inf - inf = NaN). We check for that here
    # and store 0s where there are NaNs as these rows should've been zeroed out.
    end_m_idx = (start_m + 1) * BLOCK_M
    start_m_idx = start_m * BLOCK_M
    causal_start_idx = seqlen_q - seqlen_k
    acc = acc.to(Out.type.element_ty)
    if IS_CAUSAL:  # noqa: SIM102
        if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
            out_mask_boundary = tl.full((BLOCK_DMODEL, ),
                                        causal_start_idx,
                                        dtype=tl.int32)
            mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
            out_ptrs_mask = (mask_m_offsets[:, None] >=
                             out_mask_boundary[None, :])
            z = 0.0
            acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
    # write back LSE (currently disabled)
    # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
    # If seqlen_q not multiple of BLOCK_M, we need to mask out the last
    # few rows. This is only true for the last M block. For others,
    # overflow_size will be -ve
    # overflow_size = end_m_idx - seqlen_q
    # if overflow_size > 0:
    #    boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
    #    # This is a > check because mask being 0 blocks the store.
    #    l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
    #    tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
    # else:
    #    tl.store(l_ptrs, m_i + tl.math.log2(l_i))

    # write back O
    o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
                off_h_q * stride_oh)
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # Need boundary check on this to make sure the padding from the
    # Q and KV tensors in both dims are not part of what we store back.
    # TODO: Do the boundary check optionally.
    tl.store(O_block_ptr, acc, boundary_check=(0, 1))
|
653
|
+
|
654
|
+
|
655
|
+
def check_args(
    q,
    k,
    v,
    o,
    varlen=True,
    max_seqlens=None,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
):
    """Sanity-check the tensors handed to the Triton attention kernel.

    In varlen mode ``q``/``k``/``v`` must be 3-D and unpack as
    ``(total_tokens, num_heads, head_size)``. Both cumulative
    sequence-length arguments must be present and of equal length.
    Otherwise the tensors must be 4-D, unpacking as
    ``(batch, num_heads, seqlen, head_size)``, and ``max_seqlens`` must be
    positive.

    In both modes, ``k`` and ``v`` must share a shape, and all three inputs
    must share their last dimension and dtype. ``head_size`` must be at
    most 256, ``o`` must have ``q``'s shape, and the number of query heads
    must be a multiple of the number of key heads.

    Raises:
        AssertionError: when any check fails (a ``None`` ``max_seqlens`` in
            non-varlen mode surfaces as a ``TypeError`` from the comparison).
    """
    rank = q.dim()
    assert rank == k.dim() and rank == v.dim()

    if not varlen:
        assert rank == 4
        _, q_heads, _, head_dim = q.shape
        kv_heads = k.shape[1]
        assert max_seqlens > 0
    else:
        assert rank == 3
        _, q_heads, head_dim = q.shape
        kv_heads = k.shape[1]
        assert cu_seqlens_q is not None
        assert cu_seqlens_k is not None
        assert len(cu_seqlens_q) == len(cu_seqlens_k)

    assert k.shape == v.shape
    last_dim = q.shape[-1]
    assert last_dim == k.shape[-1] and last_dim == v.shape[-1]
    # TODO: relax this once mixed q/k f8 with v f16 is supported.
    assert q.dtype == k.dtype and q.dtype == v.dtype
    assert head_dim <= 256
    assert o.shape == q.shape
    assert q_heads % kv_heads == 0
|
685
|
+
|
686
|
+
|
687
|
+
def padded_head_dim(head_size):
    """Return the BLOCK_DMODEL the kernel should be compiled with.

    The result is the smallest power of two that is both >= ``head_size``
    and >= 32. Head sizes 32, 64, 128 and 256 are returned unchanged.

    Args:
        head_size: per-head hidden dimension of q/k/v.

    Returns:
        The padded head dimension (a power of two in 32..256).

    Raises:
        AssertionError: if ``head_size`` exceeds 256.
    """
    # Previously this scanned the set {32, 64, 128, 256} for the first
    # value larger than head_size. Set iteration order is not ascending
    # (CPython yields 32, 256, 64, 128 for that set), so e.g. a head size
    # of 80 was padded to 256 instead of 128.
    assert head_size <= 256
    return max(32, 1 << (head_size - 1).bit_length())


class _attention(torch.autograd.Function):
    """Autograd wrapper that launches the Triton ``attn_fwd`` kernel.

    Only ``forward`` is defined here. Inputs are always treated as packed
    variable-length sequences described by ``cu_seqlens_q``/``cu_seqlens_k``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        o,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlens_q,
        max_seqlens_k,
        causal=False,
        sm_scale=1.0,
        bias=None,
    ):
        """Run the fused attention forward kernel.

        Args:
            ctx: autograd context; launch parameters are stashed on it.
            q, k, v: 3-D tensors unpacked as
                ``(total_tokens, num_heads, head_size)`` (enforced by
                ``check_args``).
            o: output tensor shaped like ``q``. When None it is allocated
                with ``v``'s dtype.
            cu_seqlens_q, cu_seqlens_k: cumulative sequence-length
                sequences of equal length; ``batch = len(cu_seqlens_q) - 1``.
            max_seqlens_q, max_seqlens_k: longest query / key sequence;
                ``max_seqlens_q`` sizes the first grid dimension.
            causal: forwarded to the kernel as ``IS_CAUSAL``.
            sm_scale: softmax scale forwarded to the kernel.
            bias: optional bias tensor (its first four strides are read).

        Returns:
            Tuple ``(o, encoded_softmax)``; ``encoded_softmax`` is always None.
        """
        if o is None:
            o = torch.empty_like(q, dtype=v.dtype)

        check_args(
            q,
            k,
            v,
            o,
            varlen=True,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
        )

        # Varlen layout: tensors are (total_tokens, heads, head_size). The
        # kernel expects (batch, head, seq, dim)-ordered strides. The batch
        # stride is 0 because each sequence's start offset is taken from
        # cu_seqlens_* inside the kernel.
        _, nheads_q, head_size = q.shape
        _, nheads_k, _ = k.shape
        batch = len(cu_seqlens_q) - 1
        q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
        k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
        v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
        o_strides = (0, o.stride(1), o.stride(0), o.stride(2))

        # Closest power of 2 >= head_size, and at least 32.
        padded_d_model = padded_head_dim(head_size)

        def grid(META):
            # One program per (query block, query head, sequence).
            return (
                triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
                nheads_q,
                batch,
            )

        encoded_softmax = None

        # Seed the RNG so we get reproducible results for testing.
        philox_seed = 0x1BF52
        philox_offset = 0x1D4B42

        if bias is not None:
            bias_strides = (
                bias.stride(0),
                bias.stride(1),
                bias.stride(2),
                bias.stride(3),
            )
        else:
            bias_strides = (0, 0, 0, 0)

        attn_fwd[grid](
            q,
            k,
            v,
            bias,
            sm_scale,
            None,
            o,
            *q_strides,
            *k_strides,
            *v_strides,
            *o_strides,
            *bias_strides,
            cu_seqlens_q,
            cu_seqlens_k,
            dropout_p=0.0,
            philox_seed=philox_seed,
            philox_offset_base=philox_offset,
            encoded_softmax=encoded_softmax,
            HQ=nheads_q,
            HK=nheads_k,
            ACTUAL_BLOCK_DMODEL=head_size,
            MAX_SEQLENS_Q=max_seqlens_q,
            MAX_SEQLENS_K=max_seqlens_k,
            IS_CAUSAL=causal,
            VARLEN=True,
            BLOCK_DMODEL=padded_d_model,
            BIAS_TYPE=0 if bias is None else 1,
            ENABLE_DROPOUT=False,
            RETURN_ENCODED_SOFTMAX=False,
        )

        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = head_size
        ctx.causal = causal
        ctx.dropout_p = 0.0
        ctx.philox_seed = philox_seed
        ctx.philox_offset = philox_offset
        ctx.encoded_softmax = encoded_softmax
        ctx.return_encoded_softmax = False
        return o, encoded_softmax


# Public entry point: calling ``triton_attention(q, k, v, o, ...)`` runs
# ``_attention.forward`` through autograd's ``Function.apply``.
triton_attention = _attention.apply
|