sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- sglang/bench_latency.py +17 -8
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +5 -17
- sglang/lang/backend/runtime_endpoint.py +5 -2
- sglang/lang/interpreter.py +1 -4
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +33 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +1 -3
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/fused_moe/layer.py +27 -7
- sglang/srt/layers/layernorm.py +12 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +38 -122
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +259 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +105 -71
- sglang/srt/managers/tokenizer_manager.py +17 -8
- sglang/srt/managers/tp_worker.py +188 -121
- sglang/srt/model_executor/cuda_graph_runner.py +69 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +123 -154
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +1 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/exaone.py +1 -5
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama.py +51 -5
- sglang/srt/models/llama_classification.py +1 -20
- sglang/srt/models/llava.py +30 -5
- sglang/srt/models/llavavid.py +2 -2
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +669 -0
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/olmoe.py +415 -0
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +46 -80
- sglang/srt/server.py +30 -15
- sglang/srt/server_args.py +163 -28
- sglang/srt/utils.py +19 -51
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +7 -5
- sglang/test/test_utils.py +85 -2
- sglang/utils.py +32 -37
- sglang/version.py +1 -1
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
- sglang-0.3.1.post1.dist-info/RECORD +130 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
- sglang-0.3.0.dist-info/RECORD +0 -118
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py}

@@ -15,24 +15,15 @@ limitations under the License.
 
 """
 Memory-efficient attention for decoding.
+It supports page size = 1.
 """
 
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
-import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-
-if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
-    REDUCE_TRITON_TYPE = tl.float32
-    REDUCE_TORCH_TYPE = torch.float32
-else:
-    REDUCE_TRITON_TYPE = tl.float16
-    REDUCE_TORCH_TYPE = torch.float16
-
 
 @triton.jit
 def tanh(x):
@@ -60,11 +51,13 @@ def _fwd_kernel_stage1(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
+    Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
     start_n = tl.program_id(2)
 
+    reduce_dtype = Att_Out.dtype.element_ty
     cur_kv_head = cur_head // kv_group_num
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
@@ -83,7 +76,7 @@ def _fwd_kernel_stage1(
     block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
 
     for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + off_q + start_mark).to(REDUCE_TRITON_TYPE)
+        q = tl.load(Q + off_q + start_mark).to(reduce_dtype)
         offs_n_new = cur_batch_start_index + offs_n
         k_loc = tl.load(
             Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
@@ -97,9 +90,9 @@ def _fwd_kernel_stage1(
         )
         k = tl.load(
             K_Buffer + offs_buf_k,
-            mask=offs_n_new[:, None] < cur_batch_end_index,
+            mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk),
             other=0.0,
-        ).to(REDUCE_TRITON_TYPE)
+        ).to(reduce_dtype)
         att_value = tl.sum(q[None, :] * k, 1)
         att_value *= sm_scale
 
@@ -112,7 +105,7 @@ def _fwd_kernel_stage1(
 
 @triton.jit
 def _fwd_kernel_stage2(
-    Logics,
+    logits,
     V_Buffer,
     Out,
     Req_to_tokens,
@@ -128,6 +121,7 @@ def _fwd_kernel_stage2(
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    Lv: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -159,7 +153,7 @@ def _fwd_kernel_stage2(
         )
 
         qk = tl.load(
-            Logics
+            logits
             + cur_head * stride_logic_h
             + (cur_batch_start_loc + start_n + offs_n),
             mask=start_n + offs_n < cur_batch_seq_len,
@@ -170,14 +164,16 @@ def _fwd_kernel_stage2(
         old_scale = tl.exp(e_max - n_e_max)
         p = tl.exp(qk - n_e_max)
         e_sum = e_sum * old_scale + tl.sum(p, 0)
-        v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
+        v = tl.load(
+            v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
+        )
         acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
         e_max = n_e_max
 
     acc = acc / e_sum
     off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d
     out_ptrs = Out + off_o
-    tl.store(out_ptrs, acc)
+    tl.store(out_ptrs, acc, mask=(offs_d < Lv))
 
 
 def _decode_att_m_fwd(
@@ -193,10 +189,7 @@ def _decode_att_m_fwd(
     logit_cap,
 ):
     BLOCK = 32
-
-    Lq, Lk = q.shape[-1], k_buffer.shape[-1]
-    assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256}
+    Lk = k_buffer.shape[-1]
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
 
@@ -208,6 +201,8 @@ def _decode_att_m_fwd(
     else:
         num_warps = 2
 
+    BLOCK_DMODEL = triton.next_power_of_2(Lk)
+
     _fwd_kernel_stage1[grid](
         q,
         k_buffer,
@@ -224,16 +219,17 @@ def _decode_att_m_fwd(
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=Lk,
+        BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
        num_warps=num_warps,
         num_stages=1,
+        Lk=Lk,
     )
 
 
 def _decode_softmax_reducev_fwd(
-    logics,
+    logits,
     v_buffer,
     o,
     req_to_tokens,
@@ -242,31 +238,35 @@ def _decode_softmax_reducev_fwd(
     b_seq_len,
 ):
     BLOCK = 64
-    batch, head = b_seq_len.shape[0], logics.shape[0]
+    batch, head = b_seq_len.shape[0], logits.shape[0]
     grid = (batch, head, 1)
-    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    kv_group_num = logits.shape[0] // v_buffer.shape[1]
 
     num_warps = 1
 
+    Lv = v_buffer.shape[-1]
+    BLOCK_DMODEL = triton.next_power_of_2(Lv)
+
     _fwd_kernel_stage2[grid](
-        logics,
+        logits,
         v_buffer,
         o,
         req_to_tokens,
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        logics.stride(0),
+        logits.stride(0),
         v_buffer.stride(0),
         v_buffer.stride(1),
         o.stride(0),
         o.stride(1),
         req_to_tokens.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=v_buffer.shape[-1],
+        BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_N=BLOCK,
         num_warps=num_warps,
         num_stages=3,
+        Lv=Lv,
     )
 
 
@@ -293,11 +293,13 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
     logit_cap: tl.constexpr,
+    Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_kv_head = tl.program_id(1)
     start_n = tl.program_id(2)
 
+    reduce_dtype = Att_Out.dtype.element_ty
     cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
     mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
     mask_h = mask_h & (cur_head < q_head_num)
@@ -324,9 +326,9 @@ def _fwd_grouped_kernel_stage1(
     block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
 
     for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to(
-            REDUCE_TRITON_TYPE
-        )
+        q = tl.load(
+            Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk)
+        ).to(reduce_dtype)
         offs_n_new = cur_batch_start_index + offs_n
         k_loc = tl.load(
             Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
@@ -340,13 +342,13 @@ def _fwd_grouped_kernel_stage1(
         )
         k = tl.load(
             K_Buffer + offs_buf_k,
-            mask=offs_n_new[None, :] < cur_batch_end_index,
+            mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk),
             other=0.0,
-        ).to(REDUCE_TRITON_TYPE)
+        ).to(reduce_dtype)
         qk = tl.dot(q, k)
         if BLOCK_DPE > 0:
             qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
-                REDUCE_TRITON_TYPE
+                reduce_dtype
             )
             offs_buf_kpe = (
                 k_loc[None, :] * stride_buf_kbs
@@ -357,7 +359,7 @@ def _fwd_grouped_kernel_stage1(
                 K_Buffer + offs_buf_kpe,
                 mask=offs_n_new[None, :] < cur_batch_end_index,
                 other=0.0,
-            ).to(REDUCE_TRITON_TYPE)
+            ).to(reduce_dtype)
             qk += tl.dot(qpe, kpe)
         qk *= sm_scale
 
@@ -377,7 +379,7 @@ def _fwd_grouped_kernel_stage1(
 
 @triton.jit
 def _fwd_grouped_kernel_stage2(
-    Logics,
+    logits,
     V_Buffer,
     Out,
     Req_to_tokens,
@@ -395,6 +397,7 @@ def _fwd_grouped_kernel_stage2(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
+    Lv: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_kv_head = tl.program_id(1)
@@ -432,7 +435,7 @@ def _fwd_grouped_kernel_stage2(
         )
 
         qk = tl.load(
-            Logics + offs_qk,
+            logits + offs_qk,
             mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
             other=float("-inf"),
         )
@@ -441,7 +444,9 @@ def _fwd_grouped_kernel_stage2(
         old_scale = tl.exp(e_max - n_e_max)
         p = tl.exp(qk - n_e_max[:, None])
         e_sum = e_sum * old_scale + tl.sum(p, 1)
-        v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
+        v = tl.load(
+            v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
+        )
         p = p.to(v.dtype)
         acc = acc * old_scale[:, None] + tl.dot(p, v)
         e_max = n_e_max
@@ -449,7 +454,7 @@ def _fwd_grouped_kernel_stage2(
     acc = acc / e_sum[:, None]
     off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
     out_ptrs = Out + off_o
-    tl.store(out_ptrs, acc, mask=mask_h[:, None])
+    tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv))
 
 
 def _decode_grouped_att_m_fwd(
@@ -464,17 +469,17 @@ def _decode_grouped_att_m_fwd(
     sm_scale,
     logit_cap,
 ):
-    BLOCK = 32
-
-    Lq, Lk = q.shape[-1], k_buffer.shape[-1]
-    assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256, 576}
+    BLOCK = 64
+    Lk = k_buffer.shape[-1]
 
     if Lk == 576:
         BLOCK_DMODEL = 512
         BLOCK_DPE = 64
+    elif Lk == 288:
+        BLOCK_DMODEL = 256
+        BLOCK_DPE = 32
     else:
-        BLOCK_DMODEL = Lk
+        BLOCK_DMODEL = triton.next_power_of_2(Lk)
         BLOCK_DPE = 0
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
@@ -513,11 +518,12 @@ def _decode_grouped_att_m_fwd(
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
+        Lk=Lk,
     )
 
 
 def _decode_grouped_softmax_reducev_fwd(
-    logics,
+    logits,
     v_buffer,
     o,
     req_to_tokens,
@@ -526,22 +532,25 @@ def _decode_grouped_softmax_reducev_fwd(
     b_seq_len,
 ):
     BLOCK = 128
-    batch, head_num = b_seq_len.shape[0], logics.shape[0]
-    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    batch, head_num = b_seq_len.shape[0], logits.shape[0]
+    kv_group_num = logits.shape[0] // v_buffer.shape[1]
     BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
     grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
 
     num_warps = 8
 
+    Lv = v_buffer.shape[-1]
+    BLOCK_DMODEL = triton.next_power_of_2(Lv)
+
     _fwd_grouped_kernel_stage2[grid](
-        logics,
+        logits,
         v_buffer,
         o,
         req_to_tokens,
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        logics.stride(0),
+        logits.stride(0),
         v_buffer.stride(0),
         v_buffer.stride(1),
         o.stride(0),
@@ -549,9 +558,10 @@ def _decode_grouped_softmax_reducev_fwd(
         req_to_tokens.stride(0),
         kv_group_num=kv_group_num,
         q_head_num=head_num,
-        BLOCK_DMODEL=v_buffer.shape[-1],
+        BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_N=BLOCK,
         BLOCK_H=BLOCK_H,
+        Lv=Lv,
         num_warps=num_warps,
         num_stages=1,
     )
@@ -566,17 +576,11 @@ def decode_attention_fwd(
     b_req_idx,
     b_start_loc,
     b_seq_len,
+    attn_logits,
     max_len_in_batch,
-    total_num_tokens,
     sm_scale,
-    logit_cap,
-    att_m=None,
+    logit_cap=0.0,
 ):
-    if att_m is None:
-        att_m = torch.empty(
-            (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
-        )
-
     kv_group_num = q.shape[1] // v_buffer.shape[1]
 
     if kv_group_num == 1:
@@ -584,7 +588,7 @@ def decode_attention_fwd(
         _decode_att_m_fwd(
             q,
             k_buffer,
-            att_m,
+            attn_logits,
             req_to_token,
             b_req_idx,
             b_start_loc,
@@ -594,7 +598,7 @@ def decode_attention_fwd(
             logit_cap,
         )
         _decode_softmax_reducev_fwd(
-            att_m,
+            attn_logits,
             v_buffer,
             o,
             req_to_token,
@@ -607,7 +611,7 @@ def decode_attention_fwd(
         _decode_grouped_att_m_fwd(
             q,
             k_buffer,
-            att_m,
+            attn_logits,
            req_to_token,
             b_req_idx,
             b_start_loc,
@@ -617,7 +621,7 @@ def decode_attention_fwd(
             logit_cap,
         )
         _decode_grouped_softmax_reducev_fwd(
-            att_m,
+            attn_logits,
             v_buffer,
             o,
             req_to_token,
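Note on the new calling convention (not part of the diff above): after this change, decode_attention_fwd expects the caller to pass a preallocated attn_logits buffer in place of the old total_num_tokens/att_m arguments, and the kernels take their reduction dtype from that buffer instead of the removed REDUCE_TORCH_TYPE switch. The following is a minimal usage sketch; the tensor shapes, the contiguous req_to_token layout, and the leading parameter order are illustrative assumptions inferred from the call sites shown above, not code taken from the package.

import torch

from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd

# Illustrative sizes (assumptions, not taken from the diff).
num_seqs = 4          # decoding requests in the batch
num_q_heads = 32
num_kv_heads = 8      # kv_group_num = 4 -> grouped (GQA) kernel path
head_dim = 128
max_seq_len = 256
total_tokens = num_seqs * max_seq_len
device, kv_dtype = "cuda", torch.float16

# One query token per request; KV cache laid out as [token, kv_head, head_dim].
q = torch.randn(num_seqs, num_q_heads, head_dim, dtype=kv_dtype, device=device)
o = torch.empty_like(q)
k_buffer = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=kv_dtype, device=device)
v_buffer = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=kv_dtype, device=device)

# Request -> KV-cache slot mapping; here each request simply owns a contiguous block.
req_to_token = torch.arange(total_tokens, dtype=torch.int32, device=device).view(num_seqs, max_seq_len)
b_req_idx = torch.arange(num_seqs, dtype=torch.int32, device=device)
b_seq_len = torch.full((num_seqs,), max_seq_len, dtype=torch.int32, device=device)
b_start_loc = torch.arange(0, total_tokens, max_seq_len, dtype=torch.int32, device=device)

# New in 0.3.1: the caller preallocates the intermediate logits buffer
# (one row per query head, one column per cached token in the batch).
# Its dtype drives the in-kernel reduction precision, replacing the old
# module-level REDUCE_TORCH_TYPE selection.
attn_logits = torch.empty(num_q_heads, total_tokens, dtype=torch.float32, device=device)

decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_req_idx,
    b_start_loc,
    b_seq_len,
    attn_logits,
    max_seq_len,          # max_len_in_batch
    1.0 / head_dim**0.5,  # sm_scale
    logit_cap=0.0,
)

Allocating attn_logits in float32 should reproduce the numerics of the old triton_attention_reduce_in_fp32 server option, while a float16 buffer matches the previous default reduction type.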
|