vllm-npu 0.4.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- vllm/__init__.py +23 -0
- vllm/_custom_ops.py +251 -0
- vllm/attention/__init__.py +13 -0
- vllm/attention/backends/__init__.py +0 -0
- vllm/attention/backends/abstract.py +127 -0
- vllm/attention/backends/flash_attn.py +271 -0
- vllm/attention/backends/flashinfer.py +220 -0
- vllm/attention/backends/rocm_flash_attn.py +374 -0
- vllm/attention/backends/torch_sdpa.py +250 -0
- vllm/attention/backends/xformers.py +393 -0
- vllm/attention/layer.py +56 -0
- vllm/attention/ops/__init__.py +0 -0
- vllm/attention/ops/paged_attn.py +216 -0
- vllm/attention/ops/prefix_prefill.py +792 -0
- vllm/attention/ops/triton_flash_attention.py +810 -0
- vllm/attention/selector.py +91 -0
- vllm/block.py +84 -0
- vllm/config.py +1225 -0
- vllm/core/__init__.py +0 -0
- vllm/core/block/__init__.py +0 -0
- vllm/core/block/block_table.py +295 -0
- vllm/core/block/common.py +199 -0
- vllm/core/block/cpu_gpu_block_allocator.py +228 -0
- vllm/core/block/interfaces.py +205 -0
- vllm/core/block/naive_block.py +318 -0
- vllm/core/block/prefix_caching_block.py +606 -0
- vllm/core/block_manager_v1.py +625 -0
- vllm/core/block_manager_v2.py +258 -0
- vllm/core/evictor_v1.py +105 -0
- vllm/core/evictor_v2.py +127 -0
- vllm/core/interfaces.py +113 -0
- vllm/core/policy.py +45 -0
- vllm/core/scheduler.py +1163 -0
- vllm/distributed/__init__.py +3 -0
- vllm/distributed/communication_op.py +237 -0
- vllm/distributed/device_communicators/__init__.py +0 -0
- vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
- vllm/distributed/device_communicators/pynccl.py +287 -0
- vllm/distributed/device_communicators/pynccl_utils.py +66 -0
- vllm/distributed/parallel_state.py +339 -0
- vllm/distributed/utils.py +136 -0
- vllm/engine/__init__.py +0 -0
- vllm/engine/arg_utils.py +649 -0
- vllm/engine/async_llm_engine.py +737 -0
- vllm/engine/llm_engine.py +784 -0
- vllm/engine/metrics.py +368 -0
- vllm/engine/output_processor/__init__.py +0 -0
- vllm/engine/output_processor/interfaces.py +76 -0
- vllm/engine/output_processor/multi_step.py +142 -0
- vllm/engine/output_processor/single_step.py +284 -0
- vllm/engine/output_processor/stop_checker.py +101 -0
- vllm/engine/output_processor/util.py +19 -0
- vllm/entrypoints/__init__.py +0 -0
- vllm/entrypoints/api_server.py +119 -0
- vllm/entrypoints/llm.py +259 -0
- vllm/entrypoints/openai/__init__.py +0 -0
- vllm/entrypoints/openai/api_server.py +186 -0
- vllm/entrypoints/openai/cli_args.py +115 -0
- vllm/entrypoints/openai/protocol.py +460 -0
- vllm/entrypoints/openai/serving_chat.py +392 -0
- vllm/entrypoints/openai/serving_completion.py +347 -0
- vllm/entrypoints/openai/serving_engine.py +234 -0
- vllm/envs.py +217 -0
- vllm/executor/__init__.py +0 -0
- vllm/executor/cpu_executor.py +152 -0
- vllm/executor/distributed_gpu_executor.py +115 -0
- vllm/executor/executor_base.py +115 -0
- vllm/executor/gpu_executor.py +150 -0
- vllm/executor/multiproc_worker_utils.py +263 -0
- vllm/executor/neuron_executor.py +91 -0
- vllm/executor/ray_gpu_executor.py +327 -0
- vllm/executor/ray_utils.py +119 -0
- vllm/logger.py +153 -0
- vllm/logging/__init__.py +5 -0
- vllm/logging/formatter.py +15 -0
- vllm/lora/__init__.py +0 -0
- vllm/lora/fully_sharded_layers.py +262 -0
- vllm/lora/layers.py +1181 -0
- vllm/lora/lora.py +167 -0
- vllm/lora/models.py +645 -0
- vllm/lora/punica.py +213 -0
- vllm/lora/request.py +32 -0
- vllm/lora/utils.py +98 -0
- vllm/lora/worker_manager.py +251 -0
- vllm/model_executor/__init__.py +7 -0
- vllm/model_executor/guided_decoding/__init__.py +25 -0
- vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
- vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
- vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
- vllm/model_executor/layers/__init__.py +0 -0
- vllm/model_executor/layers/activation.py +173 -0
- vllm/model_executor/layers/fused_moe/__init__.py +7 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
- vllm/model_executor/layers/layernorm.py +71 -0
- vllm/model_executor/layers/linear.py +709 -0
- vllm/model_executor/layers/logits_processor.py +115 -0
- vllm/model_executor/layers/ops/__init__.py +0 -0
- vllm/model_executor/layers/ops/rand.py +157 -0
- vllm/model_executor/layers/ops/sample.py +406 -0
- vllm/model_executor/layers/quantization/__init__.py +35 -0
- vllm/model_executor/layers/quantization/aqlm.py +376 -0
- vllm/model_executor/layers/quantization/awq.py +175 -0
- vllm/model_executor/layers/quantization/base_config.py +97 -0
- vllm/model_executor/layers/quantization/fp8.py +265 -0
- vllm/model_executor/layers/quantization/gptq.py +224 -0
- vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
- vllm/model_executor/layers/quantization/marlin.py +227 -0
- vllm/model_executor/layers/quantization/schema.py +84 -0
- vllm/model_executor/layers/quantization/squeezellm.py +137 -0
- vllm/model_executor/layers/rejection_sampler.py +405 -0
- vllm/model_executor/layers/rotary_embedding.py +525 -0
- vllm/model_executor/layers/sampler.py +1051 -0
- vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
- vllm/model_executor/model_loader/__init__.py +30 -0
- vllm/model_executor/model_loader/loader.py +362 -0
- vllm/model_executor/model_loader/neuron.py +136 -0
- vllm/model_executor/model_loader/tensorizer.py +368 -0
- vllm/model_executor/model_loader/utils.py +41 -0
- vllm/model_executor/model_loader/weight_utils.py +372 -0
- vllm/model_executor/models/__init__.py +119 -0
- vllm/model_executor/models/baichuan.py +410 -0
- vllm/model_executor/models/bloom.py +327 -0
- vllm/model_executor/models/chatglm.py +386 -0
- vllm/model_executor/models/commandr.py +373 -0
- vllm/model_executor/models/dbrx.py +413 -0
- vllm/model_executor/models/decilm.py +122 -0
- vllm/model_executor/models/deepseek.py +438 -0
- vllm/model_executor/models/falcon.py +444 -0
- vllm/model_executor/models/gemma.py +393 -0
- vllm/model_executor/models/gpt2.py +266 -0
- vllm/model_executor/models/gpt_bigcode.py +274 -0
- vllm/model_executor/models/gpt_j.py +281 -0
- vllm/model_executor/models/gpt_neox.py +295 -0
- vllm/model_executor/models/internlm2.py +323 -0
- vllm/model_executor/models/jais.py +333 -0
- vllm/model_executor/models/llama.py +442 -0
- vllm/model_executor/models/llava.py +239 -0
- vllm/model_executor/models/minicpm.py +531 -0
- vllm/model_executor/models/mixtral.py +583 -0
- vllm/model_executor/models/mixtral_quant.py +404 -0
- vllm/model_executor/models/mpt.py +295 -0
- vllm/model_executor/models/olmo.py +356 -0
- vllm/model_executor/models/opt.py +349 -0
- vllm/model_executor/models/orion.py +319 -0
- vllm/model_executor/models/phi.py +300 -0
- vllm/model_executor/models/qwen.py +284 -0
- vllm/model_executor/models/qwen2.py +367 -0
- vllm/model_executor/models/qwen2_moe.py +447 -0
- vllm/model_executor/models/stablelm.py +301 -0
- vllm/model_executor/models/starcoder2.py +302 -0
- vllm/model_executor/models/xverse.py +366 -0
- vllm/model_executor/sampling_metadata.py +588 -0
- vllm/model_executor/utils.py +35 -0
- vllm/outputs.py +150 -0
- vllm/py.typed +2 -0
- vllm/sampling_params.py +340 -0
- vllm/sequence.py +766 -0
- vllm/spec_decode/__init__.py +0 -0
- vllm/spec_decode/batch_expansion.py +397 -0
- vllm/spec_decode/interfaces.py +73 -0
- vllm/spec_decode/metrics.py +191 -0
- vllm/spec_decode/multi_step_worker.py +203 -0
- vllm/spec_decode/ngram_worker.py +176 -0
- vllm/spec_decode/spec_decode_worker.py +472 -0
- vllm/spec_decode/top1_proposer.py +200 -0
- vllm/spec_decode/util.py +228 -0
- vllm/test_utils.py +41 -0
- vllm/transformers_utils/__init__.py +0 -0
- vllm/transformers_utils/config.py +58 -0
- vllm/transformers_utils/configs/__init__.py +16 -0
- vllm/transformers_utils/configs/chatglm.py +68 -0
- vllm/transformers_utils/configs/dbrx.py +278 -0
- vllm/transformers_utils/configs/falcon.py +87 -0
- vllm/transformers_utils/configs/jais.py +236 -0
- vllm/transformers_utils/configs/mpt.py +178 -0
- vllm/transformers_utils/detokenizer.py +313 -0
- vllm/transformers_utils/tokenizer.py +149 -0
- vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
- vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
- vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
- vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
- vllm/transformers_utils/tokenizers/__init__.py +5 -0
- vllm/transformers_utils/tokenizers/baichuan.py +255 -0
- vllm/usage/__init__.py +0 -0
- vllm/usage/usage_lib.py +209 -0
- vllm/utils.py +677 -0
- vllm/worker/__init__.py +0 -0
- vllm/worker/cache_engine.py +105 -0
- vllm/worker/cpu_model_runner.py +346 -0
- vllm/worker/cpu_worker.py +321 -0
- vllm/worker/model_runner.py +1168 -0
- vllm/worker/neuron_model_runner.py +196 -0
- vllm/worker/neuron_worker.py +98 -0
- vllm/worker/worker.py +345 -0
- vllm/worker/worker_base.py +146 -0
- vllm_npu-0.4.2.dist-info/LICENSE +201 -0
- vllm_npu-0.4.2.dist-info/METADATA +173 -0
- vllm_npu-0.4.2.dist-info/RECORD +219 -0
- vllm_npu-0.4.2.dist-info/WHEEL +5 -0
- vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,792 @@
|
|
1
|
+
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
|
2
|
+
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
|
3
|
+
|
4
|
+
import torch
|
5
|
+
import triton
|
6
|
+
import triton.language as tl
|
7
|
+
|
8
|
+
if triton.__version__ >= "2.1.0":
|
9
|
+
|
10
|
+
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    K_cache,
    V_cache,
    B_Loc,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    B_Ctxlen,
    block_size,
    x,
    Out,
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    stride_k_cache_bs,
    stride_k_cache_h,
    stride_k_cache_d,
    stride_k_cache_bl,
    stride_k_cache_x,
    stride_v_cache_bs,
    stride_v_cache_h,
    stride_v_cache_d,
    stride_v_cache_bl,
    num_queries_per_kv: int,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,  # head size
    BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
    BLOCK_N: tl.constexpr,
    SLIDING_WINDOW: tl.constexpr,
):
    # Prefill attention for queries that follow an already-cached context.
    #
    # Each program instance handles one (sequence, query head, BLOCK_M-row
    # query tile) triple, taken from program ids 0, 1 and 2 respectively.
    # The tile is attended in two passes:
    #   1. against the cached context, gathered from K_cache / V_cache
    #      through the block table B_Loc (no causal mask);
    #   2. against the new keys/values in K / V (causal mask).
    # A running softmax (m_i = row max, l_i = row sum) is renormalised on
    # every step, so `acc` already holds the final output when it is stored.
    #
    # K_cache is addressed with five strides (block, head, d // x, slot in
    # block, d % x); V_cache with four (block, head, d, slot in block).
    # NOTE(review): tensor shapes/dtypes are set by the launcher, which is
    # not fully visible here -- confirm against the caller.
    seq_idx = tl.program_id(0)
    head_idx = tl.program_id(1)
    tile_idx = tl.program_id(2)

    # Several query heads may share one KV head.
    kv_head_idx = head_idx // num_queries_per_kv

    ctx_len = tl.load(B_Ctxlen + seq_idx)
    seq_len = tl.load(B_Seqlen + seq_idx)
    token_base = tl.load(B_Start_Loc + seq_idx)
    query_len = seq_len - ctx_len

    # First query row (relative to the start of the query) of this tile.
    tile_start = BLOCK_M * tile_idx

    offs_n = tl.arange(0, BLOCK_N)  # [N]
    offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)  # [D]
    offs_m = tile_start + tl.arange(0, BLOCK_M)  # [M]

    # True for real head-dim lanes, False for the power-of-2 padding lanes.
    dim_mask = tl.where(offs_d < BLOCK_DMODEL, 1, 0).to(tl.int1)  # [D]
    # True for rows that are real query tokens of this sequence.
    row_mask = offs_m[:, None] < query_len  # [M,1]

    q_offsets = ((token_base + offs_m[:, None]) * stride_qbs +
                 head_idx * stride_qh + offs_d[None, :] * stride_qd)  # [M,D]
    q = tl.load(Q + q_offsets, mask=dim_mask[None, :] & row_mask,
                other=0.0)  # [M,D]

    # Running softmax statistics and output accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # [M]
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # [M]
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
                   dtype=tl.float32)  # [M,D]

    # ---- pass 1: query tile vs. cached context (no causal mask) ----
    for start_n in range(0, ctx_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        kv_pos = start_n + offs_n  # [N] positions inside the context
        in_ctx = kv_pos < ctx_len  # [N]
        slot = kv_pos % block_size  # [N] slot inside the cache block

        # Block-table lookup: logical block index -> cache block number.
        bn = tl.load(B_Loc + seq_idx * stride_b_loc_b +
                     (kv_pos // block_size) * stride_b_loc_s,
                     mask=in_ctx,
                     other=0)  # [N]
        k_offsets = (bn[None, :] * stride_k_cache_bs +
                     kv_head_idx * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     slot[None, :] * stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)  # [D,N]
        v_offsets = (bn[:, None] * stride_v_cache_bs +
                     kv_head_idx * stride_v_cache_h +
                     offs_d[None, :] * stride_v_cache_d +
                     slot[:, None] * stride_v_cache_bl)  # [N,D]

        k = tl.load(K_cache + k_offsets,
                    mask=dim_mask[:, None] & in_ctx[None, :],
                    other=0.0)  # [D,N]

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # [M,N]
        qk += tl.dot(q, k)
        qk = tl.where(in_ctx[None, :], qk, float("-inf"))
        qk *= sm_scale
        if SLIDING_WINDOW > 0:
            # Query row m sits at sequence position ctx_len + m and key
            # column n at position kv_pos[n]; only keys closer than
            # SLIDING_WINDOW are kept.  A large finite negative is used
            # rather than -inf because a whole row may be masked, and a
            # -inf row max would turn exp() into NaN.
            qk = tl.where(
                (ctx_len + offs_m[:, None]) - kv_pos[None, :] <
                SLIDING_WINDOW, qk, -10000)

        # Block-local softmax statistics.
        m_ij = tl.max(qk, 1)  # [M]
        p = tl.exp(qk - m_ij[:, None])  # [M,N]
        l_ij = tl.sum(p, 1)  # [M]
        # Merge with the running statistics.
        m_i_new = tl.maximum(m_i, m_ij)  # [M]
        alpha = tl.exp(m_i - m_i_new)  # [M]
        beta = tl.exp(m_ij - m_i_new)  # [M]
        l_i_new = alpha * l_i + beta * l_ij  # [M]

        # Renormalise the new probabilities and the old accumulator so that
        # acc stays a correctly normalised partial result.
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]

        v = tl.load(V_cache + v_offsets,
                    mask=dim_mask[None, :] & in_ctx[:, None],
                    other=0.0)  # [N,D]

        p = p.to(v.dtype)
        acc += tl.dot(p, v)

        l_i = l_i_new
        m_i = m_i_new

    # Base pointers into the new (uncached) K / V for this KV head; the
    # per-iteration token offset is added inside the loop below.
    k_ptrs = K + (offs_n[None, :] * stride_kbs + kv_head_idx * stride_kh +
                  offs_d[:, None] * stride_kd)  # [D,N]
    v_ptrs = V + (offs_n[:, None] * stride_vbs + kv_head_idx * stride_vh +
                  offs_d[None, :] * stride_vd)  # [N,D]

    # 0 when this tile starts past the end of the query, which makes the
    # loop below run zero iterations.
    block_mask = tl.where(tile_start < query_len, 1, 0)

    # ---- pass 2: query tile vs. new keys/values (causal mask) ----
    for start_n in range(0, block_mask * (tile_idx + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        kv_pos = start_n + offs_n  # [N] positions inside the query
        in_query = kv_pos < query_len  # [N]

        k = tl.load(k_ptrs + (token_base + start_n) * stride_kbs,
                    mask=dim_mask[:, None] & in_query[None, :],
                    other=0.0)  # [D,N]

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        # Causal mask: a query row only sees keys at or before itself.
        qk = tl.where(offs_m[:, None] >= kv_pos[None, :], qk, float("-inf"))
        if SLIDING_WINDOW > 0:
            qk = tl.where(
                offs_m[:, None] - kv_pos[None, :] < SLIDING_WINDOW, qk,
                -10000)

        # Block-local softmax statistics.
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # Merge with the running statistics.
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij

        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]

        v = tl.load(v_ptrs + (token_base + start_n) * stride_vbs,
                    mask=dim_mask[None, :] & in_query[:, None],
                    other=0.0)  # [N,D]

        p = p.to(v.dtype)
        acc += tl.dot(p, v)

        l_i = l_i_new
        m_i = m_i_new

    # Write the tile back; padded head-dim lanes and rows past the end of
    # the query are skipped.
    o_offsets = ((token_base + offs_m[:, None]) * stride_obs +
                 head_idx * stride_oh + offs_d[None, :] * stride_od)
    tl.store(Out + o_offsets, acc, mask=dim_mask[None, :] & row_mask)
|
242
|
+
|
243
|
+
@triton.jit
def _fwd_kernel_flash_attn_v2(
    Q,
    K,
    V,
    K_cache,
    V_cache,
    B_Loc,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    B_Ctxlen,
    block_size,
    x,
    Out,
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    stride_k_cache_bs,
    stride_k_cache_h,
    stride_k_cache_d,
    stride_k_cache_bl,
    stride_k_cache_x,
    stride_v_cache_bs,
    stride_v_cache_h,
    stride_v_cache_d,
    stride_v_cache_bl,
    num_queries_per_kv: int,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Prefill attention with a cached context, using deferred softmax
    # normalisation: inside the loops `acc` accumulates the unnormalised sum
    # exp(qk - running_max) @ V and `l_i` the matching row sums; the single
    # division by `l_i` happens once, right before the store.
    #
    # Each program instance handles one (sequence, query head, BLOCK_M-row
    # query tile) triple.  Pass 1 attends to the cached context gathered
    # through the block table B_Loc; pass 2 attends causally to the new
    # keys/values in K / V.
    #
    # Unlike a padded-head variant, offs_d spans exactly BLOCK_DMODEL and no
    # head-dim mask is applied.
    # NOTE(review): tensor shapes/dtypes are set by the launcher -- confirm
    # against the caller.
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    # Several query heads may share one KV head.
    cur_kv_head = cur_head // num_queries_per_kv

    cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    # Number of new (uncached) tokens of this sequence.
    cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)  # [N]
    offs_d = tl.arange(0, BLOCK_DMODEL)  # [D]
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)  # [M]
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
        cur_head * stride_qh + offs_d[None, :] * stride_qd)

    q = tl.load(Q + off_q,
                mask=offs_m[:, None] < cur_batch_query_len,
                other=0.0)  # [M,D]

    # Running row max, running (unnormalised) row sum, output accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # ---- pass 1: query tile vs. cached context (no causal mask) ----
    for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # Block-table lookup: logical block index -> cache block number.
        bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                     ((start_n + offs_n) // block_size) * stride_b_loc_s,
                     mask=(start_n + offs_n) < cur_batch_ctx_len,
                     other=0)
        off_k = (bn[None, :] * stride_k_cache_bs +
                 cur_kv_head * stride_k_cache_h +
                 (offs_d[:, None] // x) * stride_k_cache_d +
                 ((start_n + offs_n[None, :]) % block_size) *
                 stride_k_cache_bl +
                 (offs_d[:, None] % x) * stride_k_cache_x)  # [D,N]
        off_v = (
            bn[:, None] * stride_v_cache_bs +
            cur_kv_head * stride_v_cache_h +
            offs_d[None, :] * stride_v_cache_d +
            (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl
        )  # [N,D]
        k = tl.load(K_cache + off_k,
                    mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                      float("-inf"))
        qk *= sm_scale

        # Probabilities are taken relative to the *new* running max, so the
        # old accumulator and row sum only need rescaling by alpha.
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        acc = acc * alpha[:, None]

        v = tl.load(V_cache + off_v,
                    mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                    other=0.0)

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # Base pointers into the new (uncached) K / V for this KV head.
    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)
    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
             offs_d[None, :] * stride_vd)
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # 0 when this tile starts past the end of the query, which makes the
    # loop below run zero iterations.
    block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

    # ---- pass 2: query tile vs. new keys/values (causal mask) ----
    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        k = tl.load(k_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_kbs,
                    mask=(start_n + offs_n[None, :]) < cur_batch_query_len,
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        # Causal mask: a query row only sees keys at or before itself.
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))

        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        acc = acc * alpha[:, None]

        v = tl.load(v_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_vbs,
                    mask=(start_n + offs_n[:, None]) < cur_batch_query_len,
                    other=0.0)

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # Final softmax normalisation.  This step was previously commented out,
    # which left the stored output as an unnormalised weighted sum of V.
    # Rows that receive no keys at all (l_i == 0) can only occur past the
    # end of the query and are excluded by the store mask below.
    acc = acc / l_i[:, None]

    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
        cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_query_len)
|
433
|
+
|
434
|
+
@triton.jit
def _fwd_kernel_alibi(
    Q,
    K,
    V,
    K_cache,
    V_cache,
    B_Loc,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    B_Ctxlen,
    Alibi_slopes,
    block_size,
    x,
    Out,
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    stride_k_cache_bs,
    stride_k_cache_h,
    stride_k_cache_d,
    stride_k_cache_bl,
    stride_k_cache_x,
    stride_v_cache_bs,
    stride_v_cache_h,
    stride_v_cache_d,
    stride_v_cache_bl,
    num_queries_per_kv: int,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Prefill attention with a cached context and an ALiBi bias.
    #
    # Each program instance handles one (sequence, query head, BLOCK_M-row
    # query tile) triple.  Pass 1 attends to the cached context gathered
    # through the block table B_Loc; pass 2 attends causally to the new
    # keys/values in K / V.  In both passes the bias
    #     (key position - query position) * slope[head]
    # is added to the scaled scores.  Softmax normalisation is deferred:
    # `acc` and `l_i` are only rescaled by alpha inside the loops, and acc
    # is divided by l_i once before the store.
    #
    # NOTE(review): tensor shapes/dtypes are set by the launcher -- confirm
    # against the caller.
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    # Several query heads may share one KV head.
    cur_kv_head = cur_head // num_queries_per_kv

    # cur_batch_seq_len: the length of prompts
    # cur_batch_ctx_len: the length of prefix
    # cur_batch_in_all_start_index: the start id of the dim=0
    cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    # Number of new (uncached) tokens of this sequence.
    cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)  # [N]
    offs_d = tl.arange(0, BLOCK_DMODEL)  # [D]
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)  # [M]
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
        cur_head * stride_qh + offs_d[None, :] * stride_qd)

    q = tl.load(Q + off_q,
                mask=offs_m[:, None] < cur_batch_query_len,
                other=0.0)  # [M,D]

    # Running row max, running (unnormalised) row sum, output accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # ALiBi state shared by both passes: the per-head slope and the position
    # in the full sequence (prefix + query) of every query row of this tile.
    # Neither is modified below, so they are computed only once.
    alibi_slope = tl.load(Alibi_slopes + cur_head)
    alibi_start_q = tl.arange(
        0, BLOCK_M) + block_start_loc + cur_batch_ctx_len  # [M]

    # ---- pass 1: query tile vs. cached context (no causal mask) ----
    # Sequence position of the first key column of the current block.
    alibi_start_k = 0
    for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # Block-table lookup: logical block index -> cache block number.
        bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                     ((start_n + offs_n) // block_size) * stride_b_loc_s,
                     mask=(start_n + offs_n) < cur_batch_ctx_len,
                     other=0)
        off_k = (bn[None, :] * stride_k_cache_bs +
                 cur_kv_head * stride_k_cache_h +
                 (offs_d[:, None] // x) * stride_k_cache_d +
                 ((start_n + offs_n[None, :]) % block_size) *
                 stride_k_cache_bl +
                 (offs_d[:, None] % x) * stride_k_cache_x)  # [D,N]
        off_v = (
            bn[:, None] * stride_v_cache_bs +
            cur_kv_head * stride_v_cache_h +
            offs_d[None, :] * stride_v_cache_d +
            (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl
        )  # [N,D]
        k = tl.load(K_cache + off_k,
                    mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                      float("-inf"))
        qk *= sm_scale

        # ALiBi bias.  Entries whose bias is positive, or whose query row
        # lies past the end of the sequence, are set to -inf.
        alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                 alibi_start_q[:, None]) * alibi_slope
        alibi = tl.where(
            (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
            alibi, float("-inf"))
        qk += alibi
        alibi_start_k += BLOCK_N

        # Probabilities are taken relative to the *new* running max, so the
        # old accumulator and row sum only need rescaling by alpha.
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        acc = acc * alpha[:, None]

        v = tl.load(V_cache + off_v,
                    mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                    other=0.0)

        p = p.to(v.dtype)
        acc += tl.dot(p, v, allow_tf32=False)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # Base pointers into the new (uncached) K / V for this KV head.
    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)
    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
             offs_d[None, :] * stride_vd)
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # 0 when this tile starts past the end of the query, which makes the
    # loop below run zero iterations.
    block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

    # ---- pass 2: query tile vs. new keys/values (causal mask) ----
    # The new keys start right after the prefix in the full sequence.
    alibi_start_k = cur_batch_ctx_len
    # calc q[BLOCK_M, BLOCK_DMODEL] mul k[prefix_len: , BLOCK_DMODEL]
    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        k = tl.load(k_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_kbs,
                    mask=(start_n + offs_n[None, :]) < cur_batch_query_len,
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, allow_tf32=False)
        qk *= sm_scale
        # Causal mask: a query row only sees keys at or before itself.
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))

        # ALiBi bias, same rule as in pass 1.
        alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                 alibi_start_q[:, None]) * alibi_slope
        alibi = tl.where(
            (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
            alibi, float("-inf"))
        qk += alibi
        alibi_start_k += BLOCK_N

        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        acc = acc * alpha[:, None]

        v = tl.load(v_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_vbs,
                    mask=(start_n + offs_n[:, None]) < cur_batch_query_len,
                    other=0.0)

        p = p.to(v.dtype)
        acc += tl.dot(p, v, allow_tf32=False)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # Deferred softmax normalisation.
    acc = acc / l_i[:, None]

    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
        cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_query_len)
|
661
|
+
|
662
|
+
@torch.inference_mode()
def context_attention_fwd(q,
                          k,
                          v,
                          o,
                          k_cache,
                          v_cache,
                          b_loc,
                          b_start_loc,
                          b_seq_len,
                          b_ctx_len,
                          max_input_len,
                          alibi_slopes=None,
                          sliding_window=None):
    """Launch the Triton prefill-attention kernel for a batch of sequences.

    Picks ``_fwd_kernel_alibi`` when ``alibi_slopes`` is given and
    ``_fwd_kernel`` otherwise, and launches it over a
    ``(batch, head, cdiv(max_input_len, BLOCK))`` grid. The kernels receive
    ``o`` as their output tensor; this function itself returns ``None``.

    Args:
        q, k, v: Query/key/value tensors. Their last dimensions must all be
            equal (the head size); ``q.shape[1]`` and ``k.shape[1]`` are read
            as the query and KV head counts.
        o: Output tensor handed to the kernel.
        k_cache: Key cache, laid out as
            ``[num_blocks, num_kv_heads, head_size/x, block_size, x]``.
        v_cache: Value cache, laid out as
            ``[num_blocks, num_kv_heads, head_size, block_size]``.
        b_loc: 2-D tensor forwarded to the kernel with both of its strides
            (presumably the per-sequence block table; confirm with caller).
        b_start_loc, b_seq_len, b_ctx_len: Per-sequence tensors forwarded to
            the kernel; ``b_seq_len.shape[0]`` is used as the batch size.
        max_input_len: Upper bound used to size the third grid dimension.
        alibi_slopes: Optional ALiBi slopes. When given, the head size must
            already be a power of two.
        sliding_window: Optional window size, passed to ``_fwd_kernel`` as
            ``SLIDING_WINDOW`` (``0`` when ``None``). NOTE(review): it is not
            forwarded on the ALiBi path.

    Raises:
        AssertionError: If the last dims of ``q``/``k``/``v`` differ, or if
            ALiBi is requested with a non-power-of-two head size.
    """
    # Larger tiles on compute capability >= 8.x, smaller ones otherwise.
    cap = torch.cuda.get_device_capability()
    BLOCK = 128 if cap[0] >= 8 else 64

    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    # round up Lk to a power of 2 - this is required for Triton block size
    Lk_padded = triton.next_power_of_2(Lk)

    sm_scale = 1.0 / (Lq**0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
    num_queries_per_kv = q.shape[1] // k.shape[1]

    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,

    # The original expression was `8 if Lk <= 64 else 8`; both arms were 8.
    num_warps = 8

    use_alibi = alibi_slopes is not None
    if use_alibi:
        # The ALiBi kernel is launched without a padded head dimension.
        assert Lk == Lk_padded

    # Positional arguments shared by both kernels, in kernel-signature order.
    # The ALiBi kernel takes `alibi_slopes` between these two groups.
    leading_args = (
        q,
        k,
        v,
        k_cache,
        v_cache,
        b_loc,
        sm_scale,
        b_start_loc,
        b_seq_len,
        b_ctx_len,
    )
    trailing_args = (
        v_cache.shape[3],  # block_size, per the v_cache layout
        8,  # presumably the `x` factor of the k_cache layout; confirm
        o,
        *(b_loc.stride(d) for d in range(2)),
        *(q.stride(d) for d in range(3)),
        *(k.stride(d) for d in range(3)),
        *(v.stride(d) for d in range(3)),
        *(o.stride(d) for d in range(3)),
        # [num_blocks, num_kv_heads, head_size/x, block_size, x]
        *(k_cache.stride(d) for d in range(5)),
        # [num_blocks, num_kv_heads, head_size, block_size]
        *(v_cache.stride(d) for d in range(4)),
    )

    if use_alibi:
        _fwd_kernel_alibi[grid](
            *leading_args,
            alibi_slopes,
            *trailing_args,
            num_queries_per_kv=num_queries_per_kv,
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_N=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
        return

    _fwd_kernel[grid](
        *leading_args,
        *trailing_args,
        num_queries_per_kv=num_queries_per_kv,
        BLOCK_M=BLOCK,
        BLOCK_DMODEL=Lk,
        BLOCK_DMODEL_PADDED=Lk_padded,
        BLOCK_N=BLOCK,
        SLIDING_WINDOW=sliding_window if sliding_window is not None else 0,
        num_warps=num_warps,
        num_stages=1,
    )
    return
|