sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/constrained/outlines_backend.py +9 -1
- sglang/srt/custom_op.py +40 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/function_call_parser.py +96 -69
- sglang/srt/layers/activation.py +10 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
- sglang/srt/layers/attention/flashinfer_backend.py +284 -39
- sglang/srt/layers/attention/triton_backend.py +124 -12
- sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
- sglang/srt/layers/layernorm.py +1 -5
- sglang/srt/layers/moe/ep_moe/layer.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
- sglang/srt/layers/moe/topk.py +4 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8_kernel.py +173 -2
- sglang/srt/layers/rotary_embedding.py +1 -3
- sglang/srt/layers/sampler.py +4 -4
- sglang/srt/lora/backend/__init__.py +8 -0
- sglang/srt/lora/backend/base_backend.py +95 -0
- sglang/srt/lora/backend/flashinfer_backend.py +91 -0
- sglang/srt/lora/backend/triton_backend.py +61 -0
- sglang/srt/lora/lora.py +127 -112
- sglang/srt/lora/lora_manager.py +50 -18
- sglang/srt/lora/triton_ops/__init__.py +5 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
- sglang/srt/model_executor/cuda_graph_runner.py +77 -80
- sglang/srt/model_executor/forward_batch_info.py +58 -59
- sglang/srt/model_executor/model_runner.py +2 -2
- sglang/srt/models/llama.py +8 -3
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/server_args.py +13 -2
- sglang/srt/speculative/build_eagle_tree.py +486 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
- sglang/srt/speculative/eagle_utils.py +420 -401
- sglang/srt/speculative/eagle_worker.py +177 -45
- sglang/srt/utils.py +7 -0
- sglang/test/runners.py +2 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/METADATA +15 -6
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/RECORD +77 -38
- sglang/srt/layers/custom_op_util.py +0 -25
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -49,11 +49,9 @@ def _fwd_kernel_stage1(
     K_Buffer,
     V_Buffer,
     sm_scale,
-
-
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     Att_Out,
-    stride_req_to_tokens_b,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -82,8 +80,9 @@ def _fwd_kernel_stage1(
     offs_dv = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lk
     mask_dv = offs_dv < Lv
-
-
+
+    cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
 
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
     q = tl.load(Q + off_q, mask=mask_d, other=0.0)
@@ -100,7 +99,7 @@ def _fwd_kernel_stage1(
     for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
         offs_n = start_n + tl.arange(0, BLOCK_N)
         kv_loc = tl.load(
-
+            kv_indices + cur_batch_kv_start_idx + offs_n,
             mask=offs_n < split_kv_end,
             other=0,
         )
@@ -173,19 +172,21 @@ def _decode_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
-
-
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     num_kv_splits,
     sm_scale,
     logit_cap,
 ):
     BLOCK = 64
+    # [TODO] work around SGPR limit on MI3xx
+    if is_hip_:
+        BLOCK = 8
     NUM_KV_SPLITS = num_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
 
-    batch, head_num =
+    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
 
     grid = (batch, head_num, NUM_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
@@ -194,6 +195,8 @@ def _decode_att_m_fwd(
         num_warps = 4
     else:
         num_warps = 2
+    if is_hip_:
+        num_warps = 1
 
     BLOCK_DMODEL = triton.next_power_of_2(Lk)
     BLOCK_DV = triton.next_power_of_2(Lv)
@@ -203,11 +206,9 @@ def _decode_att_m_fwd(
         k_buffer,
         v_buffer,
         sm_scale,
-
-
-        B_Seqlen,
+        kv_indptr,
+        kv_indices,
         att_out,
-        Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -236,11 +237,9 @@ def _fwd_grouped_kernel_stage1(
     K_Buffer,
     V_Buffer,
     sm_scale,
-
-
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     Att_Out,
-    stride_req_to_tokens_b,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -279,8 +278,9 @@ def _fwd_grouped_kernel_stage1(
     offs_dv = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lk
     mask_dv = offs_dv < Lv
-
-
+
+    cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
 
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
     q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
@@ -307,7 +307,7 @@ def _fwd_grouped_kernel_stage1(
     for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
         offs_n = start_n + tl.arange(0, BLOCK_N)
         kv_loc = tl.load(
-
+            kv_indices + cur_batch_kv_start_idx + offs_n,
             mask=offs_n < split_kv_end,
             other=0,
         )
@@ -395,9 +395,8 @@ def _decode_grouped_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
-
-
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     num_kv_splits,
     sm_scale,
     logit_cap,
@@ -421,7 +420,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    batch, head_num =
+    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     BLOCK_H = 16
@@ -433,21 +432,21 @@ def _decode_grouped_att_m_fwd(
     )
 
     extra_kargs = {}
+    num_stages = 2
     if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
-        extra_kargs = {"waves_per_eu":
+        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
+        num_stages = 1
 
     _fwd_grouped_kernel_stage1[grid](
         q,
         k_buffer,
         v_buffer,
         sm_scale,
-
-
-        B_Seqlen,
+        kv_indptr,
+        kv_indices,
         att_out,
-        Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -467,7 +466,7 @@ def _decode_grouped_att_m_fwd(
         NUM_KV_SPLITS=NUM_KV_SPLITS,
         logit_cap=logit_cap,
         num_warps=4,
-        num_stages=
+        num_stages=num_stages,
        Lk=Lk,
         Lv=Lv,
         **extra_kargs,
@@ -478,7 +477,7 @@ def _decode_grouped_att_m_fwd(
 def _fwd_kernel_stage2(
     Mid_O,
     O,
-
+    kv_indptr,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
@@ -491,7 +490,9 @@ def _fwd_kernel_stage2(
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
 
-    cur_batch_seq_len = tl.load(
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
+        kv_indptr + cur_batch
+    )
 
     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv
@@ -535,7 +536,7 @@ def _decode_softmax_reducev_fwd(
     q,
     o,
     v_buffer,
-
+    kv_indptr,
     num_kv_splits,
 ):
     batch, head_num = q.shape[0], q.shape[1]
@@ -554,7 +555,7 @@ def _decode_softmax_reducev_fwd(
     _fwd_kernel_stage2[grid](
         logits,
         o,
-
+        kv_indptr,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
@@ -574,9 +575,8 @@ def decode_attention_fwd_normal(
     k_buffer,
     v_buffer,
     o,
-
-
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
@@ -587,14 +587,13 @@ def decode_attention_fwd_normal(
         k_buffer,
         v_buffer,
         attn_logits,
-
-
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         num_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer,
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
 
 
 def decode_attention_fwd_grouped(
@@ -602,9 +601,8 @@ def decode_attention_fwd_grouped(
     k_buffer,
     v_buffer,
     o,
-
-
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
@@ -615,14 +613,13 @@ def decode_attention_fwd_grouped(
         k_buffer,
         v_buffer,
         attn_logits,
-
-
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         num_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer,
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
 
 
 def decode_attention_fwd(
@@ -630,9 +627,8 @@ def decode_attention_fwd(
     k_buffer,
     v_buffer,
     o,
-
-
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
@@ -648,9 +644,8 @@ def decode_attention_fwd(
         k_buffer,
         v_buffer,
         o,
-
-
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         attn_logits,
         num_kv_splits,
         sm_scale,
@@ -663,9 +658,8 @@ def decode_attention_fwd(
         k_buffer,
         v_buffer,
         o,
-
-
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         attn_logits,
         num_kv_splits,
         sm_scale,
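Across these hunks, the per-request arguments visible in the old code (B_Seqlen / b_seq_len, stride_req_to_tokens_b, Req_to_tokens.stride(0)) are replaced by a flattened, CSR-style pair: kv_indptr holds one running offset per request plus a leading zero, and kv_indices holds the concatenated KV-cache slot indices, so a kernel reads request i as kv_indices[kv_indptr[i]:kv_indptr[i+1]] and takes the batch size as kv_indptr.shape[0] - 1. Below is a minimal sketch of how a caller could assemble such tensors; the tensor names, sizes, and the req_to_token layout are illustrative assumptions, not code from this release.

import torch

# Hypothetical inputs for three requests; names, dtypes, and sizes are assumptions.
b_seq_len = torch.tensor([5, 3, 7], dtype=torch.int32)              # KV length per request
b_req_idx = torch.tensor([2, 0, 1], dtype=torch.int32)              # row of each request in the token table
req_to_token = torch.randint(0, 4096, (8, 512), dtype=torch.int32)  # request -> KV-cache slot mapping

# kv_indptr: CSR-style offsets with a leading zero; inside the kernels,
# kv_indptr[i + 1] - kv_indptr[i] recovers request i's sequence length.
kv_indptr = torch.zeros(b_seq_len.numel() + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(b_seq_len, dim=0)

# kv_indices: the KV-cache slot indices of all requests, concatenated in order,
# so kv_indices[kv_indptr[i]:kv_indptr[i + 1]] lists the slots of request i.
kv_indices = torch.cat(
    [req_to_token[r, :n] for r, n in zip(b_req_idx.tolist(), b_seq_len.tolist())]
)

assert kv_indices.numel() == int(kv_indptr[-1])

In the server these tensors would live on the same device as q, k_buffer, and v_buffer and be passed to decode_attention_fwd in place of the old b_seq_len argument.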
sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py

@@ -3,6 +3,13 @@ import triton
 import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.utils import is_hip
+
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
+
+is_hip_ = is_hip()
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -274,9 +281,6 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq):
     return
 
 
-import torch
-
-
 def flash_decode_attention_fwd(
     q,
     k_buffer,
@@ -770,3 +774,333 @@ def flash_decode_sparse_attention_fwd(
     )
 
     sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ)
+
+
+# Extend attention kernel for Double Sparsity
+# Moved from https://github.com/sgl-project/sglang/blob/v0.4.2.post1/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+@triton.jit
+def _fwd_kernel(
+    Q_Extend,
+    K_Extend,
+    V_Extend,
+    O_Extend,
+    K_Buffer,
+    V_Buffer,
+    Req_to_tokens,
+    B_req_idx,
+    B_Seq_Len,
+    B_Start_Loc_Extend,
+    B_Seq_Len_Extend,
+    sm_scale,
+    kv_group_num,
+    stride_qbs,
+    stride_qh,
+    stride_kbs,
+    stride_kh,
+    stride_vbs,
+    stride_vh,
+    stride_obs,
+    stride_oh,
+    stride_buf_kbs,
+    stride_buf_kh,
+    stride_buf_vbs,
+    stride_buf_vh,
+    stride_req_to_tokens_b,
+    logit_cap: tl.constexpr,
+    Lq: tl.constexpr,
+    Lv: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    cur_seq = tl.program_id(0)
+    cur_head = tl.program_id(1)
+    cur_block_m = tl.program_id(2)
+    cur_kv_head = cur_head // kv_group_num
+
+    cur_seq_len = tl.load(B_Seq_Len + cur_seq)
+    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
+    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
+
+    cur_seq_prefix_start_in_loc = 0
+    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
+
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
+    offs_m = tl.arange(0, BLOCK_M)
+    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
+
+    mask_d = offs_d < Lq
+    mask_dv = offs_dv < Lv
+
+    offs_q = (
+        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        * stride_qbs
+        + cur_head * stride_qh
+        + offs_d[None, :]
+    )
+    q = tl.load(
+        Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
+    )
+
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        offs_qpe = (
+            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            * stride_qbs
+            + cur_head * stride_qh
+            + offs_dpe[None, :]
+        )
+        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
+
+    # stage 1: compute scores with prefix
+    offs_n = tl.arange(0, BLOCK_N)
+
+    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
+    deno = tl.zeros([BLOCK_M], dtype=tl.float32)
+    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+
+    for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        mask_n = (start_n + offs_n) < cur_seq_len_prefix
+        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
+            cur_seq_prefix_start_in_loc + start_n + offs_n
+        )
+        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
+
+        # load k in transposed way
+        offs_buf_k = (
+            offs_kv_loc[None, :] * stride_buf_kbs
+            + cur_kv_head * stride_buf_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
+
+        qk = tl.dot(q.to(k.dtype), k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe.to(kpe.dtype), kpe)
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        re_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        deno = deno * re_scale + tl.sum(p, 1)
+
+        offs_buf_v = (
+            offs_kv_loc[:, None] * stride_buf_vbs
+            + cur_kv_head * stride_buf_vh
+            + offs_dv[None, :]
+        )
+        v = tl.load(
+            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
+        p = p.to(v.dtype)
+        acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+        e_max = n_e_max
+
+    # stage 2: compute the trianlge part
+
+    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    for start_n in range(0, cur_block_m_end, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        mask_n = (start_n + offs_n) < cur_block_m_end
+
+        # load k in transposed way
+        offs_k = (
+            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
+            + cur_kv_head * stride_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
+
+        qk = tl.dot(q, k, out_dtype=tl.float32)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
+                * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Extend + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
+
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
+            start_n + offs_n[None, :]
+        )
+        mask_causual &= mask_m[:, None] & mask_n[None, :]
+        qk = tl.where(mask_causual, qk, float("-inf"))
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        re_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        deno = deno * re_scale + tl.sum(p, 1)
+
+        offs_v = (
+            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+            + cur_kv_head * stride_vh
+            + offs_dv[None, :]
+        )
+        v = tl.load(
+            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
+        p = p.to(v.dtype)
+        acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+        e_max = n_e_max
+
+    offs_o = (
+        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        * stride_obs
+        + cur_head * stride_oh
+        + offs_dv[None, :]
+    )
+    tl.store(
+        O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
+    )
+
+
+def extend_attention_fwd(
+    q_extend,
+    k_extend,
+    v_extend,
+    o_extend,
+    k_buffer,
+    v_buffer,
+    req_to_tokens,
+    b_req_idx,
+    b_seq_len,
+    b_seq_len_extend,
+    b_start_loc_extend,
+    max_len_extend,
+    sm_scale=None,
+    logit_cap=0.0,
+):
+    """
+    q_extend, k_extend, v_extend, o_extend: contiguous tensors
+
+    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
+    """
+    Lq, Lk, Lv = (
+        q_extend.shape[-1],
+        k_extend.shape[-1],
+        v_extend.shape[-1],
+    )
+
+    if Lq == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    elif Lq == 288:
+        BLOCK_DMODEL = 256
+        BLOCK_DPE = 32
+    elif Lq == 192:
+        BLOCK_DMODEL = 128
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = triton.next_power_of_2(Lq)
+        BLOCK_DPE = 0
+    BLOCK_DV = triton.next_power_of_2(Lv)
+
+    if is_hip_:
+        BLOCK_M, BLOCK_N = (64, 64)
+        num_warps = 4
+
+    else:
+        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+            if Lq <= 256:
+                BLOCK_M, BLOCK_N = (128, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (128, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        else:
+            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+
+        num_warps = 4 if Lk <= 64 else 8
+
+    sm_scale = sm_scale or 1.0 / (Lq**0.5)
+    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
+    kv_group_num = q_extend.shape[1] // k_extend.shape[1]
+
+    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
+    num_stages = 1
+
+    extra_kargs = {}
+    if is_hip_:
+        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+
+    _fwd_kernel[grid](
+        q_extend,
+        k_extend,
+        v_extend,
+        o_extend,
+        k_buffer,
+        v_buffer,
+        req_to_tokens,
+        b_req_idx,
+        b_seq_len,
+        b_start_loc_extend,
+        b_seq_len_extend,
+        sm_scale,
+        kv_group_num,
+        q_extend.stride(0),
+        q_extend.stride(1),
+        k_extend.stride(0),
+        k_extend.stride(1),
+        v_extend.stride(0),
+        v_extend.stride(1),
+        o_extend.stride(0),
+        o_extend.stride(1),
+        k_buffer.stride(0),
+        k_buffer.stride(1),
+        v_buffer.stride(0),
+        v_buffer.stride(1),
+        req_to_tokens.stride(0),
+        logit_cap=logit_cap,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DV=BLOCK_DV,
+        BLOCK_M=BLOCK_M,
+        BLOCK_N=BLOCK_N,
+        Lq=Lq,
+        Lv=Lv,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        **extra_kargs,
+    )