sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_offline_throughput.py +18 -6
- sglang/bench_one_batch.py +13 -0
- sglang/bench_serving.py +8 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/constrained/outlines_backend.py +5 -0
- sglang/srt/constrained/xgrammar_backend.py +9 -6
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +22 -5
- sglang/srt/layers/attention/torch_native_backend.py +22 -8
- sglang/srt/layers/attention/triton_backend.py +38 -33
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +665 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
- sglang/srt/layers/fused_moe_triton/layer.py +1 -1
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/quantization/__init__.py +2 -47
- sglang/srt/layers/quantization/fp8.py +607 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +11 -2
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/torchao_utils.py +58 -45
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +39 -24
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +236 -197
- sglang/srt/managers/tokenizer_manager.py +99 -58
- sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -11
- sglang/srt/model_executor/model_runner.py +24 -9
- sglang/srt/model_parallel.py +67 -10
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/deepseek_v2.py +87 -7
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +72 -13
- sglang/srt/models/llama.py +22 -5
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +12 -9
- sglang/srt/models/phi3_small.py +0 -5
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +0 -5
- sglang/srt/models/torch_native_llama.py +0 -5
- sglang/srt/openai_api/adapter.py +4 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +4 -4
- sglang/srt/server_args.py +62 -13
- sglang/srt/utils.py +57 -10
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
@@ -17,8 +17,11 @@ It supports page size = 1.
|
|
17
17
|
"""
|
18
18
|
|
19
19
|
# Adapted from
|
20
|
-
# https://github.com/ModelTC/lightllm/blob/
|
21
|
-
# https://github.com/ModelTC/lightllm/blob/
|
20
|
+
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
|
21
|
+
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
|
22
|
+
|
23
|
+
import logging
|
24
|
+
|
22
25
|
import triton
|
23
26
|
import triton.language as tl
|
24
27
|
|
@@ -26,6 +29,13 @@ from sglang.srt.utils import is_hip
|
|
26
29
|
|
27
30
|
is_hip_ = is_hip()
|
28
31
|
|
32
|
+
logger = logging.getLogger(__name__)
|
33
|
+
|
34
|
+
# TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
|
35
|
+
logger.warning(
|
36
|
+
"The following error message 'operation scheduled before its operands' can be ignored."
|
37
|
+
)
|
38
|
+
|
29
39
|
|
30
40
|
@triton.jit
|
31
41
|
def tanh(x):
|
@@ -37,10 +47,10 @@ def tanh(x):
|
|
37
47
|
def _fwd_kernel_stage1(
|
38
48
|
Q,
|
39
49
|
K_Buffer,
|
50
|
+
V_Buffer,
|
40
51
|
sm_scale,
|
41
52
|
Req_to_tokens,
|
42
53
|
B_req_idx,
|
43
|
-
B_Start_Loc,
|
44
54
|
B_Seqlen,
|
45
55
|
Att_Out,
|
46
56
|
stride_req_to_tokens_b,
|
@@ -48,152 +58,136 @@ def _fwd_kernel_stage1(
|
|
48
58
|
stride_qh,
|
49
59
|
stride_buf_kbs,
|
50
60
|
stride_buf_kh,
|
51
|
-
|
61
|
+
stride_buf_vbs,
|
62
|
+
stride_buf_vh,
|
63
|
+
stride_mid_ob,
|
64
|
+
stride_mid_oh,
|
65
|
+
stride_mid_os,
|
52
66
|
kv_group_num: tl.constexpr,
|
53
67
|
BLOCK_DMODEL: tl.constexpr,
|
68
|
+
BLOCK_DV: tl.constexpr,
|
54
69
|
BLOCK_N: tl.constexpr,
|
55
|
-
|
70
|
+
NUM_KV_SPLITS: tl.constexpr,
|
56
71
|
logit_cap: tl.constexpr,
|
57
72
|
Lk: tl.constexpr,
|
73
|
+
Lv: tl.constexpr,
|
58
74
|
):
|
59
75
|
cur_batch = tl.program_id(0)
|
60
76
|
cur_head = tl.program_id(1)
|
61
|
-
|
77
|
+
split_kv_id = tl.program_id(2)
|
62
78
|
|
63
|
-
reduce_dtype = Att_Out.dtype.element_ty
|
64
79
|
cur_kv_head = cur_head // kv_group_num
|
65
80
|
|
66
81
|
offs_d = tl.arange(0, BLOCK_DMODEL)
|
82
|
+
offs_dv = tl.arange(0, BLOCK_DV)
|
83
|
+
mask_d = offs_d < Lk
|
84
|
+
mask_dv = offs_dv < Lv
|
67
85
|
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
68
|
-
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
69
86
|
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
|
70
87
|
|
71
88
|
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
|
72
|
-
q = tl.load(Q + off_q
|
73
|
-
|
74
|
-
kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
|
75
|
-
split_k_start = kv_len_per_split * split_k_id
|
76
|
-
split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)
|
77
|
-
|
78
|
-
for start_n in range(split_k_start, split_k_end, BLOCK_N):
|
79
|
-
offs_n = start_n + tl.arange(0, BLOCK_N)
|
80
|
-
k_loc = tl.load(
|
81
|
-
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
|
82
|
-
mask=offs_n < split_k_end,
|
83
|
-
other=0,
|
84
|
-
)
|
85
|
-
offs_buf_k = (
|
86
|
-
k_loc[:, None] * stride_buf_kbs
|
87
|
-
+ cur_kv_head * stride_buf_kh
|
88
|
-
+ offs_d[None, :]
|
89
|
-
)
|
90
|
-
k = tl.load(
|
91
|
-
K_Buffer + offs_buf_k,
|
92
|
-
mask=(offs_n[:, None] < split_k_end) & (offs_d[None, :] < Lk),
|
93
|
-
other=0.0,
|
94
|
-
).to(reduce_dtype)
|
95
|
-
att_value = tl.sum(q[None, :] * k, 1)
|
96
|
-
att_value *= sm_scale
|
89
|
+
q = tl.load(Q + off_q, mask=mask_d, other=0.0)
|
97
90
|
|
98
|
-
|
99
|
-
|
91
|
+
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
|
92
|
+
split_kv_start = kv_len_per_split * split_kv_id
|
93
|
+
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
|
100
94
|
|
101
|
-
|
102
|
-
|
95
|
+
e_max = -float("inf")
|
96
|
+
e_sum = 0.0
|
97
|
+
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
|
98
|
+
|
99
|
+
if split_kv_end > split_kv_start:
|
100
|
+
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
101
|
+
offs_n = start_n + tl.arange(0, BLOCK_N)
|
102
|
+
kv_loc = tl.load(
|
103
|
+
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
|
104
|
+
mask=offs_n < split_kv_end,
|
105
|
+
other=0,
|
106
|
+
)
|
107
|
+
offs_buf_k = (
|
108
|
+
kv_loc[:, None] * stride_buf_kbs
|
109
|
+
+ cur_kv_head * stride_buf_kh
|
110
|
+
+ offs_d[None, :]
|
111
|
+
)
|
112
|
+
k = tl.load(
|
113
|
+
K_Buffer + offs_buf_k,
|
114
|
+
mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
|
115
|
+
other=0.0,
|
116
|
+
)
|
117
|
+
qk = tl.sum(q[None, :] * k, 1)
|
118
|
+
qk *= sm_scale
|
103
119
|
|
120
|
+
if logit_cap > 0:
|
121
|
+
qk = logit_cap * tanh(qk / logit_cap)
|
104
122
|
|
105
|
-
|
106
|
-
def _fwd_kernel_stage2(
|
107
|
-
logits,
|
108
|
-
V_Buffer,
|
109
|
-
Out,
|
110
|
-
Req_to_tokens,
|
111
|
-
B_req_idx,
|
112
|
-
B_Start_Loc,
|
113
|
-
B_Seqlen,
|
114
|
-
stride_logic_h,
|
115
|
-
stride_buf_vbs,
|
116
|
-
stride_buf_vh,
|
117
|
-
stride_obs,
|
118
|
-
stride_oh,
|
119
|
-
stride_req_to_token_b,
|
120
|
-
kv_group_num: tl.constexpr,
|
121
|
-
BLOCK_DMODEL: tl.constexpr,
|
122
|
-
BLOCK_N: tl.constexpr,
|
123
|
-
Lv: tl.constexpr,
|
124
|
-
):
|
125
|
-
cur_batch = tl.program_id(0)
|
126
|
-
cur_head = tl.program_id(1)
|
123
|
+
qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
|
127
124
|
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
125
|
+
offs_buf_v = (
|
126
|
+
kv_loc[:, None] * stride_buf_vbs
|
127
|
+
+ cur_kv_head * stride_buf_vh
|
128
|
+
+ offs_dv[None, :]
|
129
|
+
)
|
130
|
+
v = tl.load(
|
131
|
+
V_Buffer + offs_buf_v,
|
132
|
+
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
|
133
|
+
other=0.0,
|
134
|
+
)
|
133
135
|
|
134
|
-
|
135
|
-
|
136
|
+
n_e_max = tl.maximum(tl.max(qk, 0), e_max)
|
137
|
+
re_scale = tl.exp(e_max - n_e_max)
|
138
|
+
p = tl.exp(qk - n_e_max)
|
139
|
+
acc *= re_scale
|
140
|
+
acc += tl.sum(p[:, None] * v, 0)
|
136
141
|
|
137
|
-
|
138
|
-
|
142
|
+
e_sum = e_sum * re_scale + tl.sum(p, 0)
|
143
|
+
e_max = n_e_max
|
139
144
|
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
start_n = tl.multiple_of(start_n, BLOCK_N)
|
146
|
-
v_index = tl.load(
|
147
|
-
Req_to_tokens
|
148
|
-
+ cur_batch_req_idx * stride_req_to_token_b
|
149
|
-
+ (start_n + offs_n),
|
150
|
-
mask=(start_n + offs_n) < cur_batch_seq_len,
|
151
|
-
other=0,
|
145
|
+
offs_mid_o = (
|
146
|
+
cur_batch * stride_mid_ob
|
147
|
+
+ cur_head * stride_mid_oh
|
148
|
+
+ split_kv_id * stride_mid_os
|
149
|
+
+ offs_dv
|
152
150
|
)
|
153
151
|
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
mask=start_n + offs_n < cur_batch_seq_len,
|
159
|
-
other=float("-inf"),
|
152
|
+
tl.store(
|
153
|
+
Att_Out + offs_mid_o,
|
154
|
+
acc / e_sum,
|
155
|
+
mask=(mask_dv),
|
160
156
|
)
|
161
157
|
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
|
158
|
+
offs_mid_o_1 = (
|
159
|
+
cur_batch * stride_mid_ob
|
160
|
+
+ cur_head * stride_mid_oh
|
161
|
+
+ split_kv_id * stride_mid_os
|
162
|
+
+ Lv
|
168
163
|
)
|
169
|
-
acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
|
170
|
-
e_max = n_e_max
|
171
164
|
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
165
|
+
tl.store(
|
166
|
+
Att_Out + offs_mid_o_1,
|
167
|
+
e_max + tl.log(e_sum),
|
168
|
+
)
|
176
169
|
|
177
170
|
|
178
171
|
def _decode_att_m_fwd(
|
179
172
|
q,
|
180
173
|
k_buffer,
|
174
|
+
v_buffer,
|
181
175
|
att_out,
|
182
176
|
Req_to_tokens,
|
183
177
|
B_req_idx,
|
184
|
-
B_Start_Loc,
|
185
178
|
B_Seqlen,
|
186
|
-
|
179
|
+
num_kv_splits,
|
187
180
|
sm_scale,
|
188
181
|
logit_cap,
|
189
182
|
):
|
190
|
-
BLOCK =
|
191
|
-
|
183
|
+
BLOCK = 64
|
184
|
+
NUM_KV_SPLITS = num_kv_splits
|
192
185
|
Lk = k_buffer.shape[-1]
|
186
|
+
Lv = v_buffer.shape[-1]
|
193
187
|
|
194
188
|
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
195
189
|
|
196
|
-
grid = (batch, head_num,
|
190
|
+
grid = (batch, head_num, NUM_KV_SPLITS)
|
197
191
|
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
198
192
|
|
199
193
|
if kv_group_num == 1:
|
@@ -202,14 +196,15 @@ def _decode_att_m_fwd(
|
|
202
196
|
num_warps = 2
|
203
197
|
|
204
198
|
BLOCK_DMODEL = triton.next_power_of_2(Lk)
|
199
|
+
BLOCK_DV = triton.next_power_of_2(Lv)
|
205
200
|
|
206
201
|
_fwd_kernel_stage1[grid](
|
207
202
|
q,
|
208
203
|
k_buffer,
|
204
|
+
v_buffer,
|
209
205
|
sm_scale,
|
210
206
|
Req_to_tokens,
|
211
207
|
B_req_idx,
|
212
|
-
B_Start_Loc,
|
213
208
|
B_Seqlen,
|
214
209
|
att_out,
|
215
210
|
Req_to_tokens.stride(0),
|
@@ -217,56 +212,20 @@ def _decode_att_m_fwd(
|
|
217
212
|
q.stride(1),
|
218
213
|
k_buffer.stride(0),
|
219
214
|
k_buffer.stride(1),
|
215
|
+
v_buffer.stride(0),
|
216
|
+
v_buffer.stride(1),
|
220
217
|
att_out.stride(0),
|
218
|
+
att_out.stride(1),
|
219
|
+
att_out.stride(2),
|
221
220
|
kv_group_num=kv_group_num,
|
222
221
|
BLOCK_DMODEL=BLOCK_DMODEL,
|
222
|
+
BLOCK_DV=BLOCK_DV,
|
223
223
|
BLOCK_N=BLOCK,
|
224
|
-
|
224
|
+
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
225
225
|
logit_cap=logit_cap,
|
226
226
|
num_warps=num_warps,
|
227
|
-
num_stages=
|
227
|
+
num_stages=2,
|
228
228
|
Lk=Lk,
|
229
|
-
)
|
230
|
-
|
231
|
-
|
232
|
-
def _decode_softmax_reducev_fwd(
|
233
|
-
logits,
|
234
|
-
v_buffer,
|
235
|
-
o,
|
236
|
-
req_to_tokens,
|
237
|
-
b_req_idx,
|
238
|
-
b_start_loc,
|
239
|
-
b_seq_len,
|
240
|
-
):
|
241
|
-
BLOCK = 64
|
242
|
-
batch, head = b_seq_len.shape[0], logits.shape[0]
|
243
|
-
grid = (batch, head, 1)
|
244
|
-
kv_group_num = logits.shape[0] // v_buffer.shape[1]
|
245
|
-
|
246
|
-
num_warps = 1
|
247
|
-
|
248
|
-
Lv = v_buffer.shape[-1]
|
249
|
-
BLOCK_DMODEL = triton.next_power_of_2(Lv)
|
250
|
-
|
251
|
-
_fwd_kernel_stage2[grid](
|
252
|
-
logits,
|
253
|
-
v_buffer,
|
254
|
-
o,
|
255
|
-
req_to_tokens,
|
256
|
-
b_req_idx,
|
257
|
-
b_start_loc,
|
258
|
-
b_seq_len,
|
259
|
-
logits.stride(0),
|
260
|
-
v_buffer.stride(0),
|
261
|
-
v_buffer.stride(1),
|
262
|
-
o.stride(0),
|
263
|
-
o.stride(1),
|
264
|
-
req_to_tokens.stride(0),
|
265
|
-
kv_group_num=kv_group_num,
|
266
|
-
BLOCK_DMODEL=BLOCK_DMODEL,
|
267
|
-
BLOCK_N=BLOCK,
|
268
|
-
num_warps=num_warps,
|
269
|
-
num_stages=3,
|
270
229
|
Lv=Lv,
|
271
230
|
)
|
272
231
|
|
@@ -275,10 +234,10 @@ def _decode_softmax_reducev_fwd(
|
|
275
234
|
def _fwd_grouped_kernel_stage1(
|
276
235
|
Q,
|
277
236
|
K_Buffer,
|
237
|
+
V_Buffer,
|
278
238
|
sm_scale,
|
279
239
|
Req_to_tokens,
|
280
240
|
B_req_idx,
|
281
|
-
B_Start_Loc,
|
282
241
|
B_Seqlen,
|
283
242
|
Att_Out,
|
284
243
|
stride_req_to_tokens_b,
|
@@ -286,23 +245,27 @@ def _fwd_grouped_kernel_stage1(
|
|
286
245
|
stride_qh,
|
287
246
|
stride_buf_kbs,
|
288
247
|
stride_buf_kh,
|
289
|
-
|
248
|
+
stride_buf_vbs,
|
249
|
+
stride_buf_vh,
|
250
|
+
stride_mid_ob,
|
251
|
+
stride_mid_oh,
|
252
|
+
stride_mid_os,
|
290
253
|
kv_group_num: tl.constexpr,
|
291
254
|
q_head_num: tl.constexpr,
|
292
255
|
BLOCK_DMODEL: tl.constexpr,
|
293
256
|
BLOCK_DPE: tl.constexpr,
|
257
|
+
BLOCK_DV: tl.constexpr,
|
294
258
|
BLOCK_N: tl.constexpr,
|
295
259
|
BLOCK_H: tl.constexpr,
|
296
|
-
|
260
|
+
NUM_KV_SPLITS: tl.constexpr,
|
297
261
|
logit_cap: tl.constexpr,
|
298
262
|
Lk: tl.constexpr,
|
263
|
+
Lv: tl.constexpr,
|
299
264
|
):
|
300
265
|
cur_batch = tl.program_id(0)
|
301
266
|
cur_head_id = tl.program_id(1)
|
302
267
|
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
|
303
|
-
|
304
|
-
|
305
|
-
reduce_dtype = Att_Out.dtype.element_ty
|
268
|
+
split_kv_id = tl.program_id(2)
|
306
269
|
|
307
270
|
if BLOCK_H < kv_group_num:
|
308
271
|
VALID_BLOCK_H: tl.constexpr = BLOCK_H
|
@@ -313,171 +276,135 @@ def _fwd_grouped_kernel_stage1(
|
|
313
276
|
mask_h = mask_h & (cur_head < q_head_num)
|
314
277
|
|
315
278
|
offs_d = tl.arange(0, BLOCK_DMODEL)
|
279
|
+
offs_dv = tl.arange(0, BLOCK_DV)
|
280
|
+
mask_d = offs_d < Lk
|
281
|
+
mask_dv = offs_dv < Lv
|
316
282
|
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
317
|
-
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
318
283
|
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
|
319
284
|
|
320
285
|
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
|
321
|
-
q = tl.load(
|
322
|
-
Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0
|
323
|
-
).to(reduce_dtype)
|
286
|
+
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
|
324
287
|
|
325
288
|
if BLOCK_DPE > 0:
|
326
289
|
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
290
|
+
mask_dpe = offs_dpe < Lk
|
327
291
|
off_qpe = (
|
328
292
|
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
|
329
293
|
)
|
330
|
-
qpe = tl.load(
|
331
|
-
|
332
|
-
kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
|
333
|
-
split_k_start = kv_len_per_split * split_k_id
|
334
|
-
split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)
|
335
|
-
|
336
|
-
for start_n in range(split_k_start, split_k_end, BLOCK_N):
|
337
|
-
offs_n = start_n + tl.arange(0, BLOCK_N)
|
338
|
-
k_loc = tl.load(
|
339
|
-
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
|
340
|
-
mask=offs_n < split_k_end,
|
341
|
-
other=0,
|
342
|
-
)
|
343
|
-
offs_buf_k = (
|
344
|
-
k_loc[None, :] * stride_buf_kbs
|
345
|
-
+ cur_kv_head * stride_buf_kh
|
346
|
-
+ offs_d[:, None]
|
294
|
+
qpe = tl.load(
|
295
|
+
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
|
347
296
|
)
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
297
|
+
|
298
|
+
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
|
299
|
+
split_kv_start = kv_len_per_split * split_kv_id
|
300
|
+
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
|
301
|
+
|
302
|
+
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
|
303
|
+
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
|
304
|
+
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
|
305
|
+
|
306
|
+
if split_kv_end > split_kv_start:
|
307
|
+
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
308
|
+
offs_n = start_n + tl.arange(0, BLOCK_N)
|
309
|
+
kv_loc = tl.load(
|
310
|
+
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
|
311
|
+
mask=offs_n < split_kv_end,
|
312
|
+
other=0,
|
313
|
+
)
|
314
|
+
offs_buf_k = (
|
315
|
+
kv_loc[None, :] * stride_buf_kbs
|
357
316
|
+ cur_kv_head * stride_buf_kh
|
358
|
-
+
|
317
|
+
+ offs_d[:, None]
|
359
318
|
)
|
360
|
-
|
361
|
-
K_Buffer +
|
362
|
-
mask=offs_n[None, :] <
|
319
|
+
k = tl.load(
|
320
|
+
K_Buffer + offs_buf_k,
|
321
|
+
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
|
363
322
|
other=0.0,
|
364
|
-
)
|
365
|
-
qk
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
Out,
|
387
|
-
Req_to_tokens,
|
388
|
-
B_req_idx,
|
389
|
-
B_Start_Loc,
|
390
|
-
B_Seqlen,
|
391
|
-
stride_logic_h,
|
392
|
-
stride_buf_vbs,
|
393
|
-
stride_buf_vh,
|
394
|
-
stride_obs,
|
395
|
-
stride_oh,
|
396
|
-
stride_req_to_token_b,
|
397
|
-
kv_group_num: tl.constexpr,
|
398
|
-
q_head_num: tl.constexpr,
|
399
|
-
BLOCK_DMODEL: tl.constexpr,
|
400
|
-
BLOCK_N: tl.constexpr,
|
401
|
-
BLOCK_H: tl.constexpr,
|
402
|
-
Lv: tl.constexpr,
|
403
|
-
):
|
404
|
-
cur_batch = tl.program_id(0)
|
405
|
-
cur_head_id = tl.program_id(1)
|
406
|
-
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
|
407
|
-
|
408
|
-
if BLOCK_H < kv_group_num:
|
409
|
-
VALID_BLOCK_H: tl.constexpr = BLOCK_H
|
410
|
-
else:
|
411
|
-
VALID_BLOCK_H: tl.constexpr = kv_group_num
|
412
|
-
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
|
413
|
-
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
|
414
|
-
mask_h = mask_h & (cur_head < q_head_num)
|
323
|
+
)
|
324
|
+
qk = tl.dot(q, k.to(q.dtype))
|
325
|
+
if BLOCK_DPE > 0:
|
326
|
+
offs_buf_kpe = (
|
327
|
+
kv_loc[None, :] * stride_buf_kbs
|
328
|
+
+ cur_kv_head * stride_buf_kh
|
329
|
+
+ offs_dpe[:, None]
|
330
|
+
)
|
331
|
+
kpe = tl.load(
|
332
|
+
K_Buffer + offs_buf_kpe,
|
333
|
+
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
|
334
|
+
other=0.0,
|
335
|
+
)
|
336
|
+
qk += tl.dot(qpe, kpe.to(qpe.dtype))
|
337
|
+
qk *= sm_scale
|
338
|
+
|
339
|
+
if logit_cap > 0:
|
340
|
+
qk = logit_cap * tanh(qk / logit_cap)
|
341
|
+
|
342
|
+
qk = tl.where(
|
343
|
+
mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
|
344
|
+
)
|
415
345
|
|
416
|
-
|
417
|
-
|
418
|
-
|
346
|
+
offs_buf_v = (
|
347
|
+
kv_loc[:, None] * stride_buf_vbs
|
348
|
+
+ cur_kv_head * stride_buf_vh
|
349
|
+
+ offs_dv[None, :]
|
350
|
+
)
|
351
|
+
v = tl.load(
|
352
|
+
V_Buffer + offs_buf_v,
|
353
|
+
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
|
354
|
+
other=0.0,
|
355
|
+
)
|
419
356
|
|
420
|
-
|
421
|
-
|
357
|
+
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
358
|
+
re_scale = tl.exp(e_max - n_e_max)
|
359
|
+
p = tl.exp(qk - n_e_max[:, None])
|
360
|
+
acc *= re_scale[:, None]
|
361
|
+
acc += tl.dot(p.to(v.dtype), v)
|
422
362
|
|
423
|
-
|
424
|
-
|
363
|
+
e_sum = e_sum * re_scale + tl.sum(p, 1)
|
364
|
+
e_max = n_e_max
|
425
365
|
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
start_n = tl.multiple_of(start_n, BLOCK_N)
|
432
|
-
v_index = tl.load(
|
433
|
-
Req_to_tokens
|
434
|
-
+ cur_batch_req_idx * stride_req_to_token_b
|
435
|
-
+ (start_n + offs_n),
|
436
|
-
mask=(start_n + offs_n) < cur_batch_seq_len,
|
437
|
-
other=0,
|
366
|
+
offs_mid_o = (
|
367
|
+
cur_batch * stride_mid_ob
|
368
|
+
+ cur_head[:, None] * stride_mid_oh
|
369
|
+
+ split_kv_id * stride_mid_os
|
370
|
+
+ offs_dv[None, :]
|
438
371
|
)
|
439
372
|
|
440
|
-
|
441
|
-
|
373
|
+
tl.store(
|
374
|
+
Att_Out + offs_mid_o,
|
375
|
+
acc / e_sum[:, None],
|
376
|
+
mask=(mask_h[:, None]) & (mask_dv[None, :]),
|
442
377
|
)
|
443
378
|
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
379
|
+
offs_mid_o_1 = (
|
380
|
+
cur_batch * stride_mid_ob
|
381
|
+
+ cur_head * stride_mid_oh
|
382
|
+
+ split_kv_id * stride_mid_os
|
383
|
+
+ Lv
|
448
384
|
)
|
449
385
|
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
v = tl.load(
|
455
|
-
v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
|
386
|
+
tl.store(
|
387
|
+
Att_Out + offs_mid_o_1,
|
388
|
+
e_max + tl.log(e_sum),
|
389
|
+
mask=mask_h,
|
456
390
|
)
|
457
|
-
p = p.to(v.dtype)
|
458
|
-
acc = acc * old_scale[:, None] + tl.dot(p, v)
|
459
|
-
e_max = n_e_max
|
460
|
-
|
461
|
-
acc = acc / e_sum[:, None]
|
462
|
-
off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
|
463
|
-
out_ptrs = Out + off_o
|
464
|
-
tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv))
|
465
391
|
|
466
392
|
|
467
393
|
def _decode_grouped_att_m_fwd(
|
468
394
|
q,
|
469
395
|
k_buffer,
|
396
|
+
v_buffer,
|
470
397
|
att_out,
|
471
398
|
Req_to_tokens,
|
472
399
|
B_req_idx,
|
473
|
-
B_Start_Loc,
|
474
400
|
B_Seqlen,
|
475
|
-
|
401
|
+
num_kv_splits,
|
476
402
|
sm_scale,
|
477
403
|
logit_cap,
|
478
404
|
):
|
479
|
-
BLOCK =
|
405
|
+
BLOCK = 32
|
480
406
|
Lk = k_buffer.shape[-1]
|
407
|
+
Lv = v_buffer.shape[-1]
|
481
408
|
|
482
409
|
if Lk == 576:
|
483
410
|
BLOCK_DMODEL = 512
|
@@ -488,20 +415,19 @@ def _decode_grouped_att_m_fwd(
|
|
488
415
|
else:
|
489
416
|
BLOCK_DMODEL = triton.next_power_of_2(Lk)
|
490
417
|
BLOCK_DPE = 0
|
418
|
+
BLOCK_DV = triton.next_power_of_2(Lv)
|
491
419
|
|
492
420
|
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
493
421
|
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
494
422
|
|
495
|
-
BLOCK_H =
|
496
|
-
|
423
|
+
BLOCK_H = 16
|
424
|
+
NUM_KV_SPLITS = num_kv_splits
|
497
425
|
grid = (
|
498
426
|
batch,
|
499
427
|
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
|
500
|
-
|
428
|
+
NUM_KV_SPLITS,
|
501
429
|
)
|
502
430
|
|
503
|
-
num_warps = 4
|
504
|
-
|
505
431
|
extra_kargs = {}
|
506
432
|
if is_hip_:
|
507
433
|
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
@@ -511,10 +437,10 @@ def _decode_grouped_att_m_fwd(
|
|
511
437
|
_fwd_grouped_kernel_stage1[grid](
|
512
438
|
q,
|
513
439
|
k_buffer,
|
440
|
+
v_buffer,
|
514
441
|
sm_scale,
|
515
442
|
Req_to_tokens,
|
516
443
|
B_req_idx,
|
517
|
-
B_Start_Loc,
|
518
444
|
B_Seqlen,
|
519
445
|
att_out,
|
520
446
|
Req_to_tokens.stride(0),
|
@@ -522,41 +448,97 @@ def _decode_grouped_att_m_fwd(
|
|
522
448
|
q.stride(1),
|
523
449
|
k_buffer.stride(0),
|
524
450
|
k_buffer.stride(1),
|
451
|
+
v_buffer.stride(0),
|
452
|
+
v_buffer.stride(1),
|
525
453
|
att_out.stride(0),
|
454
|
+
att_out.stride(1),
|
455
|
+
att_out.stride(2),
|
526
456
|
kv_group_num=kv_group_num,
|
527
457
|
q_head_num=head_num,
|
528
458
|
BLOCK_DMODEL=BLOCK_DMODEL,
|
529
459
|
BLOCK_DPE=BLOCK_DPE,
|
460
|
+
BLOCK_DV=BLOCK_DV,
|
530
461
|
BLOCK_N=BLOCK,
|
531
462
|
BLOCK_H=BLOCK_H,
|
532
|
-
|
463
|
+
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
533
464
|
logit_cap=logit_cap,
|
534
|
-
num_warps=
|
535
|
-
num_stages=
|
465
|
+
num_warps=4,
|
466
|
+
num_stages=2,
|
536
467
|
Lk=Lk,
|
468
|
+
Lv=Lv,
|
537
469
|
**extra_kargs,
|
538
470
|
)
|
539
471
|
|
540
472
|
|
541
|
-
|
473
|
+
@triton.jit
|
474
|
+
def _fwd_kernel_stage2(
|
475
|
+
Mid_O,
|
476
|
+
O,
|
477
|
+
B_Seqlen,
|
478
|
+
stride_mid_ob,
|
479
|
+
stride_mid_oh,
|
480
|
+
stride_mid_os,
|
481
|
+
stride_obs,
|
482
|
+
stride_oh,
|
483
|
+
NUM_KV_SPLITS: tl.constexpr,
|
484
|
+
BLOCK_DV: tl.constexpr,
|
485
|
+
Lv: tl.constexpr,
|
486
|
+
):
|
487
|
+
cur_batch = tl.program_id(0)
|
488
|
+
cur_head = tl.program_id(1)
|
489
|
+
|
490
|
+
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
491
|
+
|
492
|
+
offs_d = tl.arange(0, BLOCK_DV)
|
493
|
+
mask_d = offs_d < Lv
|
494
|
+
|
495
|
+
e_sum = 0.0
|
496
|
+
e_max = -float("inf")
|
497
|
+
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
|
498
|
+
|
499
|
+
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
|
500
|
+
offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
|
501
|
+
|
502
|
+
for split_kv_id in range(0, NUM_KV_SPLITS):
|
503
|
+
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
|
504
|
+
split_kv_start = kv_len_per_split * split_kv_id
|
505
|
+
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
|
506
|
+
|
507
|
+
if split_kv_end > split_kv_start:
|
508
|
+
tv = tl.load(
|
509
|
+
Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
|
510
|
+
)
|
511
|
+
tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
|
512
|
+
n_e_max = tl.maximum(tlogic, e_max)
|
513
|
+
|
514
|
+
old_scale = tl.exp(e_max - n_e_max)
|
515
|
+
acc *= old_scale
|
516
|
+
exp_logic = tl.exp(tlogic - n_e_max)
|
517
|
+
acc += exp_logic * tv
|
518
|
+
|
519
|
+
e_sum = e_sum * old_scale + exp_logic
|
520
|
+
e_max = n_e_max
|
521
|
+
|
522
|
+
tl.store(
|
523
|
+
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
|
524
|
+
acc / e_sum,
|
525
|
+
mask=mask_d,
|
526
|
+
)
|
527
|
+
|
528
|
+
|
529
|
+
def _decode_softmax_reducev_fwd(
|
542
530
|
logits,
|
543
|
-
|
531
|
+
q,
|
544
532
|
o,
|
545
|
-
|
546
|
-
b_req_idx,
|
547
|
-
b_start_loc,
|
533
|
+
v_buffer,
|
548
534
|
b_seq_len,
|
535
|
+
num_kv_splits,
|
549
536
|
):
|
550
|
-
|
551
|
-
batch, head_num = b_seq_len.shape[0], logits.shape[0]
|
552
|
-
kv_group_num = logits.shape[0] // v_buffer.shape[1]
|
553
|
-
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
|
554
|
-
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
|
555
|
-
|
556
|
-
num_warps = 8
|
557
|
-
|
537
|
+
batch, head_num = q.shape[0], q.shape[1]
|
558
538
|
Lv = v_buffer.shape[-1]
|
559
|
-
|
539
|
+
BLOCK_DV = triton.next_power_of_2(Lv)
|
540
|
+
|
541
|
+
NUM_KV_SPLITS = num_kv_splits
|
560
542
|
|
561
543
|
extra_kargs = {}
|
562
544
|
if is_hip_:
|
@@ -564,28 +546,21 @@ def _decode_grouped_softmax_reducev_fwd(
|
|
564
546
|
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
565
547
|
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
566
548
|
|
567
|
-
|
549
|
+
grid = (batch, head_num)
|
550
|
+
_fwd_kernel_stage2[grid](
|
568
551
|
logits,
|
569
|
-
v_buffer,
|
570
552
|
o,
|
571
|
-
req_to_tokens,
|
572
|
-
b_req_idx,
|
573
|
-
b_start_loc,
|
574
553
|
b_seq_len,
|
575
554
|
logits.stride(0),
|
576
|
-
|
577
|
-
|
555
|
+
logits.stride(1),
|
556
|
+
logits.stride(2),
|
578
557
|
o.stride(0),
|
579
558
|
o.stride(1),
|
580
|
-
|
581
|
-
|
582
|
-
q_head_num=head_num,
|
583
|
-
BLOCK_DMODEL=BLOCK_DMODEL,
|
584
|
-
BLOCK_N=BLOCK,
|
585
|
-
BLOCK_H=BLOCK_H,
|
559
|
+
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
560
|
+
BLOCK_DV=BLOCK_DV,
|
586
561
|
Lv=Lv,
|
587
|
-
num_warps=
|
588
|
-
num_stages=
|
562
|
+
num_warps=4,
|
563
|
+
num_stages=2,
|
589
564
|
**extra_kargs,
|
590
565
|
)
|
591
566
|
|
@@ -597,34 +572,25 @@ def decode_attention_fwd_normal(
|
|
597
572
|
o,
|
598
573
|
req_to_token,
|
599
574
|
b_req_idx,
|
600
|
-
b_start_loc,
|
601
575
|
b_seq_len,
|
602
576
|
attn_logits,
|
603
|
-
|
577
|
+
num_kv_splits,
|
604
578
|
sm_scale,
|
605
579
|
logit_cap=0.0,
|
606
580
|
):
|
607
581
|
_decode_att_m_fwd(
|
608
582
|
q,
|
609
583
|
k_buffer,
|
584
|
+
v_buffer,
|
610
585
|
attn_logits,
|
611
586
|
req_to_token,
|
612
587
|
b_req_idx,
|
613
|
-
b_start_loc,
|
614
588
|
b_seq_len,
|
615
|
-
|
589
|
+
num_kv_splits,
|
616
590
|
sm_scale,
|
617
591
|
logit_cap,
|
618
592
|
)
|
619
|
-
_decode_softmax_reducev_fwd(
|
620
|
-
attn_logits,
|
621
|
-
v_buffer,
|
622
|
-
o,
|
623
|
-
req_to_token,
|
624
|
-
b_req_idx,
|
625
|
-
b_start_loc,
|
626
|
-
b_seq_len,
|
627
|
-
)
|
593
|
+
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
|
628
594
|
|
629
595
|
|
630
596
|
def decode_attention_fwd_grouped(
|
@@ -634,34 +600,25 @@ def decode_attention_fwd_grouped(
|
|
634
600
|
o,
|
635
601
|
req_to_token,
|
636
602
|
b_req_idx,
|
637
|
-
b_start_loc,
|
638
603
|
b_seq_len,
|
639
604
|
attn_logits,
|
640
|
-
|
605
|
+
num_kv_splits,
|
641
606
|
sm_scale,
|
642
607
|
logit_cap=0.0,
|
643
608
|
):
|
644
609
|
_decode_grouped_att_m_fwd(
|
645
610
|
q,
|
646
611
|
k_buffer,
|
612
|
+
v_buffer,
|
647
613
|
attn_logits,
|
648
614
|
req_to_token,
|
649
615
|
b_req_idx,
|
650
|
-
b_start_loc,
|
651
616
|
b_seq_len,
|
652
|
-
|
617
|
+
num_kv_splits,
|
653
618
|
sm_scale,
|
654
619
|
logit_cap,
|
655
620
|
)
|
656
|
-
|
657
|
-
attn_logits,
|
658
|
-
v_buffer,
|
659
|
-
o,
|
660
|
-
req_to_token,
|
661
|
-
b_req_idx,
|
662
|
-
b_start_loc,
|
663
|
-
b_seq_len,
|
664
|
-
)
|
621
|
+
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
|
665
622
|
|
666
623
|
|
667
624
|
def decode_attention_fwd(
|
@@ -671,13 +628,13 @@ def decode_attention_fwd(
|
|
671
628
|
o,
|
672
629
|
req_to_token,
|
673
630
|
b_req_idx,
|
674
|
-
b_start_loc,
|
675
631
|
b_seq_len,
|
676
632
|
attn_logits,
|
677
|
-
|
633
|
+
num_kv_splits,
|
678
634
|
sm_scale,
|
679
635
|
logit_cap=0.0,
|
680
636
|
):
|
637
|
+
assert num_kv_splits == attn_logits.shape[2]
|
681
638
|
kv_group_num = q.shape[1] // v_buffer.shape[1]
|
682
639
|
|
683
640
|
if kv_group_num == 1:
|
@@ -689,10 +646,9 @@ def decode_attention_fwd(
|
|
689
646
|
o,
|
690
647
|
req_to_token,
|
691
648
|
b_req_idx,
|
692
|
-
b_start_loc,
|
693
649
|
b_seq_len,
|
694
650
|
attn_logits,
|
695
|
-
|
651
|
+
num_kv_splits,
|
696
652
|
sm_scale,
|
697
653
|
logit_cap,
|
698
654
|
)
|
@@ -705,10 +661,9 @@ def decode_attention_fwd(
|
|
705
661
|
o,
|
706
662
|
req_to_token,
|
707
663
|
b_req_idx,
|
708
|
-
b_start_loc,
|
709
664
|
b_seq_len,
|
710
665
|
attn_logits,
|
711
|
-
|
666
|
+
num_kv_splits,
|
712
667
|
sm_scale,
|
713
668
|
logit_cap,
|
714
669
|
)
|