sglang 0.4.2.post2__py3-none-any.whl → 0.4.2.post4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry and is provided for informational purposes only.
- sglang/check_env.py +1 -0
- sglang/srt/constrained/outlines_backend.py +4 -1
- sglang/srt/function_call_parser.py +96 -69
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
- sglang/srt/layers/attention/flashinfer_backend.py +34 -41
- sglang/srt/layers/attention/triton_backend.py +64 -16
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -5
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8_kernel.py +43 -10
- sglang/srt/lora/backend/__init__.py +25 -5
- sglang/srt/lora/backend/base_backend.py +31 -9
- sglang/srt/lora/backend/flashinfer_backend.py +41 -4
- sglang/srt/lora/backend/triton_backend.py +34 -4
- sglang/srt/lora/layers.py +293 -0
- sglang/srt/lora/lora.py +101 -326
- sglang/srt/lora/lora_manager.py +101 -269
- sglang/srt/lora/mem_pool.py +174 -0
- sglang/srt/lora/triton_ops/__init__.py +7 -1
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
- sglang/srt/lora/utils.py +141 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -0
- sglang/srt/models/llama.py +8 -3
- sglang/srt/speculative/build_eagle_tree.py +482 -102
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +134 -61
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/METADATA +4 -4
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/RECORD +49 -32
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/top_level.txt +0 -0
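Most of the newly added files are pre-tuned Triton GEMM configurations for FP8 W8A8 block quantization on NVIDIA H20 GPUs, keyed by matrix shape. The snippet below is a minimal sketch of that filename convention, reconstructed only from the names in the list above; `fp8_block_config_filename` is a hypothetical helper for illustration, not sglang's actual config loader.

```python
# Hypothetical helper: rebuilds the tuned-config filename pattern seen in the
# file list above (N/K shape, device name, dtype, block shape). Not sglang API.
def fp8_block_config_filename(N, K, device_name="NVIDIA_H20",
                              dtype="fp8_w8a8", block_shape=(128, 128)):
    return (
        f"N={N},K={K},device_name={device_name},"
        f"dtype={dtype},block_shape=[{block_shape[0]}, {block_shape[1]}].json"
    )

print(fp8_block_config_filename(1536, 7168))
# -> N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
```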
sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py

@@ -3,6 +3,13 @@ import triton
 import triton.language as tl

 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.utils import is_hip
+
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
+
+is_hip_ = is_hip()

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32

@@ -274,9 +281,6 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq):
     return


-import torch
-
-
 def flash_decode_attention_fwd(
     q,
     k_buffer,

@@ -770,3 +774,333 @@ def flash_decode_sparse_attention_fwd(
     )

     sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ)
+
+
+# Extend attention kernel for Double Sparsity
+# Moved from https://github.com/sgl-project/sglang/blob/v0.4.2.post1/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+@triton.jit
+def _fwd_kernel(
+    Q_Extend,
+    K_Extend,
+    V_Extend,
+    O_Extend,
+    K_Buffer,
+    V_Buffer,
+    Req_to_tokens,
+    B_req_idx,
+    B_Seq_Len,
+    B_Start_Loc_Extend,
+    B_Seq_Len_Extend,
+    sm_scale,
+    kv_group_num,
+    stride_qbs,
+    stride_qh,
+    stride_kbs,
+    stride_kh,
+    stride_vbs,
+    stride_vh,
+    stride_obs,
+    stride_oh,
+    stride_buf_kbs,
+    stride_buf_kh,
+    stride_buf_vbs,
+    stride_buf_vh,
+    stride_req_to_tokens_b,
+    logit_cap: tl.constexpr,
+    Lq: tl.constexpr,
+    Lv: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    cur_seq = tl.program_id(0)
+    cur_head = tl.program_id(1)
+    cur_block_m = tl.program_id(2)
+    cur_kv_head = cur_head // kv_group_num
+
+    cur_seq_len = tl.load(B_Seq_Len + cur_seq)
+    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
+    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
+
+    cur_seq_prefix_start_in_loc = 0
+    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
+
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
+    offs_m = tl.arange(0, BLOCK_M)
+    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
+
+    mask_d = offs_d < Lq
+    mask_dv = offs_dv < Lv
+
+    offs_q = (
+        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        * stride_qbs
+        + cur_head * stride_qh
+        + offs_d[None, :]
+    )
+    q = tl.load(
+        Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
+    )
+
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        offs_qpe = (
+            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            * stride_qbs
+            + cur_head * stride_qh
+            + offs_dpe[None, :]
+        )
+        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
+
+    # stage 1: compute scores with prefix
+    offs_n = tl.arange(0, BLOCK_N)
+
+    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
+    deno = tl.zeros([BLOCK_M], dtype=tl.float32)
+    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+
+    for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        mask_n = (start_n + offs_n) < cur_seq_len_prefix
+        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
+            cur_seq_prefix_start_in_loc + start_n + offs_n
+        )
+        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
+
+        # load k in transposed way
+        offs_buf_k = (
+            offs_kv_loc[None, :] * stride_buf_kbs
+            + cur_kv_head * stride_buf_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
+
+        qk = tl.dot(q.to(k.dtype), k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe.to(kpe.dtype), kpe)
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        re_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        deno = deno * re_scale + tl.sum(p, 1)
+
+        offs_buf_v = (
+            offs_kv_loc[:, None] * stride_buf_vbs
+            + cur_kv_head * stride_buf_vh
+            + offs_dv[None, :]
+        )
+        v = tl.load(
+            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
+        p = p.to(v.dtype)
+        acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+        e_max = n_e_max
+
+    # stage 2: compute the trianlge part
+
+    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    for start_n in range(0, cur_block_m_end, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        mask_n = (start_n + offs_n) < cur_block_m_end
+
+        # load k in transposed way
+        offs_k = (
+            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
+            + cur_kv_head * stride_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
+
+        qk = tl.dot(q, k, out_dtype=tl.float32)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
+                * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Extend + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
+
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
+            start_n + offs_n[None, :]
+        )
+        mask_causual &= mask_m[:, None] & mask_n[None, :]
+        qk = tl.where(mask_causual, qk, float("-inf"))
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        re_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        deno = deno * re_scale + tl.sum(p, 1)
+
+        offs_v = (
+            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+            + cur_kv_head * stride_vh
+            + offs_dv[None, :]
+        )
+        v = tl.load(
+            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
+        p = p.to(v.dtype)
+        acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+        e_max = n_e_max
+
+    offs_o = (
+        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        * stride_obs
+        + cur_head * stride_oh
+        + offs_dv[None, :]
+    )
+    tl.store(
+        O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
+    )
+
+
+def extend_attention_fwd(
+    q_extend,
+    k_extend,
+    v_extend,
+    o_extend,
+    k_buffer,
+    v_buffer,
+    req_to_tokens,
+    b_req_idx,
+    b_seq_len,
+    b_seq_len_extend,
+    b_start_loc_extend,
+    max_len_extend,
+    sm_scale=None,
+    logit_cap=0.0,
+):
+    """
+    q_extend, k_extend, v_extend, o_extend: contiguous tensors
+
+    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
+    """
+    Lq, Lk, Lv = (
+        q_extend.shape[-1],
+        k_extend.shape[-1],
+        v_extend.shape[-1],
+    )
+
+    if Lq == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    elif Lq == 288:
+        BLOCK_DMODEL = 256
+        BLOCK_DPE = 32
+    elif Lq == 192:
+        BLOCK_DMODEL = 128
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = triton.next_power_of_2(Lq)
+        BLOCK_DPE = 0
+    BLOCK_DV = triton.next_power_of_2(Lv)
+
+    if is_hip_:
+        BLOCK_M, BLOCK_N = (64, 64)
+        num_warps = 4
+
+    else:
+        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+            if Lq <= 256:
+                BLOCK_M, BLOCK_N = (128, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (128, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        else:
+            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+
+        num_warps = 4 if Lk <= 64 else 8
+
+    sm_scale = sm_scale or 1.0 / (Lq**0.5)
+    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
+    kv_group_num = q_extend.shape[1] // k_extend.shape[1]
+
+    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
+    num_stages = 1
+
+    extra_kargs = {}
+    if is_hip_:
+        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+
+    _fwd_kernel[grid](
+        q_extend,
+        k_extend,
+        v_extend,
+        o_extend,
+        k_buffer,
+        v_buffer,
+        req_to_tokens,
+        b_req_idx,
+        b_seq_len,
+        b_start_loc_extend,
+        b_seq_len_extend,
+        sm_scale,
+        kv_group_num,
+        q_extend.stride(0),
+        q_extend.stride(1),
+        k_extend.stride(0),
+        k_extend.stride(1),
+        v_extend.stride(0),
+        v_extend.stride(1),
+        o_extend.stride(0),
+        o_extend.stride(1),
+        k_buffer.stride(0),
+        k_buffer.stride(1),
+        v_buffer.stride(0),
+        v_buffer.stride(1),
+        req_to_tokens.stride(0),
+        logit_cap=logit_cap,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DV=BLOCK_DV,
+        BLOCK_M=BLOCK_M,
+        BLOCK_N=BLOCK_N,
+        Lq=Lq,
+        Lv=Lv,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        **extra_kargs,
+    )
sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -46,11 +46,11 @@ def _fwd_kernel(
     O_Extend,
     K_Buffer,
     V_Buffer,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seq_Len,
-    B_Start_Loc_Extend,
-    B_Seq_Len_Extend,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
+    mask_ptr,
+    mask_offsets,
     sm_scale,
     kv_group_num,
     stride_qbs,

@@ -65,7 +65,6 @@ def _fwd_kernel(
     stride_buf_kh,
     stride_buf_vbs,
     stride_buf_vh,
-    stride_req_to_tokens_b,
     logit_cap: tl.constexpr,
     Lq: tl.constexpr,
     Lv: tl.constexpr,

@@ -74,19 +73,21 @@ def _fwd_kernel(
     BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    USE_CUSTOM_MASK: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
     cur_block_m = tl.program_id(2)
     cur_kv_head = cur_head // kv_group_num

-    cur_seq_len = tl.load(B_Seq_Len + cur_seq)
-    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
-    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
+    cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq)
+    cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
+    cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
+    cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
+    cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend

-    cur_seq_prefix_start_in_loc = 0
-    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
+    if USE_CUSTOM_MASK:
+        cur_seq_mask_start_idx = tl.load(mask_offsets + cur_seq)

     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)

@@ -97,7 +98,7 @@ def _fwd_kernel(
     mask_dv = offs_dv < Lv

     offs_q = (
-        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
         + cur_head * stride_qh
         + offs_d[None, :]

@@ -109,7 +110,7 @@ def _fwd_kernel(
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
         offs_qpe = (
-            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
             * stride_qbs
             + cur_head * stride_qh
             + offs_dpe[None, :]

@@ -126,10 +127,9 @@ def _fwd_kernel(
     for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix
-        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
-            cur_seq_prefix_start_in_loc + start_n + offs_n
+        offs_kv_loc = tl.load(
+            kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
         )
-        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)

         # load k in transposed way
         offs_buf_k = (

@@ -159,7 +159,20 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)

-        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
+        if USE_CUSTOM_MASK:
+            custom_mask = tl.load(
+                mask_ptr
+                + cur_seq_mask_start_idx
+                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + start_n
+                + offs_n[None, :],
+                mask=(mask_m[:, None] & mask_n[None, :]),
+                other=0,
+            )
+            custom_mask &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(custom_mask, qk, float("-inf"))
+        else:
+            qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))

         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)

@@ -179,7 +192,7 @@ def _fwd_kernel(

         e_max = n_e_max

-    # stage 2: compute the trianlge part
+    # stage 2: compute the triangle part

     cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
     for start_n in range(0, cur_block_m_end, BLOCK_N):

@@ -188,7 +201,7 @@ def _fwd_kernel(

         # load k in transposed way
         offs_k = (
-            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
+            (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
             + cur_kv_head * stride_kh
             + offs_d[:, None]
         )

@@ -199,8 +212,7 @@ def _fwd_kernel(
         qk = tl.dot(q, k, out_dtype=tl.float32)
         if BLOCK_DPE > 0:
             offs_kpe = (
-                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
-                * stride_kbs
+                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
                 + cur_kv_head * stride_kh
                 + offs_dpe[:, None]
             )

@@ -216,11 +228,25 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)

-        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
-            start_n + offs_n[None, :]
-        )
-        mask_causual &= mask_m[:, None] & mask_n[None, :]
-        qk = tl.where(mask_causual, qk, float("-inf"))
+        if USE_CUSTOM_MASK:
+            custom_mask = tl.load(
+                mask_ptr
+                + cur_seq_mask_start_idx
+                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + cur_seq_len_prefix
+                + start_n
+                + offs_n[None, :],
+                mask=(mask_m[:, None] & mask_n[None, :]),
+                other=0,
+            )
+            custom_mask &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(custom_mask, qk, float("-inf"))
+        else:
+            mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
+                start_n + offs_n[None, :]
+            )
+            mask_causual &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_causual, qk, float("-inf"))

         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)

@@ -228,7 +254,7 @@ def _fwd_kernel(
         deno = deno * re_scale + tl.sum(p, 1)

         offs_v = (
-            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+            (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
             + cur_kv_head * stride_vh
             + offs_dv[None, :]
         )

@@ -241,7 +267,7 @@ def _fwd_kernel(
         e_max = n_e_max

     offs_o = (
-        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_obs
         + cur_head * stride_oh
         + offs_dv[None, :]

@@ -258,11 +284,11 @@ def extend_attention_fwd(
     o_extend,
     k_buffer,
     v_buffer,
-    req_to_tokens,
-    b_req_idx,
-    b_seq_len,
-    b_seq_len_extend,
-    b_start_loc_extend,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
+    custom_mask,
+    mask_offsets,
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,

@@ -315,15 +341,17 @@ def extend_attention_fwd(
     num_warps = 4 if Lk <= 64 else 8

     sm_scale = sm_scale or 1.0 / (Lq**0.5)
-    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
+    batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]

+    USE_CUSTOM_MASK = custom_mask is not None
+
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
     num_stages = 1

     extra_kargs = {}
     if is_hip_:
-        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}

     _fwd_kernel[grid](
         q_extend,

@@ -332,11 +360,11 @@ def extend_attention_fwd(
         o_extend,
         k_buffer,
         v_buffer,
-        req_to_tokens,
-        b_req_idx,
-        b_seq_len,
-        b_start_loc_extend,
-        b_seq_len_extend,
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        custom_mask,
+        mask_offsets,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),

@@ -351,7 +379,6 @@ def extend_attention_fwd(
         k_buffer.stride(1),
         v_buffer.stride(0),
         v_buffer.stride(1),
-        req_to_tokens.stride(0),
        logit_cap=logit_cap,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,

@@ -360,6 +387,7 @@ def extend_attention_fwd(
         BLOCK_N=BLOCK_N,
         Lq=Lq,
         Lv=Lv,
+        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
         num_warps=num_warps,
         num_stages=num_stages,
         **extra_kargs,
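The refactor above drops the `req_to_tokens` / `b_req_idx` / `b_seq_len` arguments in favor of CSR-style `qo_indptr` / `kv_indptr` / `kv_indices` plus an optional flattened `custom_mask` with per-request `mask_offsets`. Below is a rough sketch of how those buffers line up with the kernel's indexing, under assumed per-request lengths; the variable names are illustrative, not sglang internals.

```python
import torch

# Illustrative sketch only: assembling the CSR-style arguments introduced above.
# prefix_lens / extend_lens are assumed inputs; none of these names are sglang's.
prefix_lens = torch.tensor([5, 0, 3])   # cached KV tokens per request
extend_lens = torch.tensor([4, 7, 2])   # new (extend) tokens per request
batch_size = prefix_lens.numel()

# qo_indptr / kv_indptr: exclusive prefix sums, length batch_size + 1.
qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_lens, dim=0)
kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)

# kv_indices: token-pool slots of all prefix tokens, flattened in request order.
# Here we simply pretend the slots are consecutive integers.
kv_indices = torch.arange(int(prefix_lens.sum()), dtype=torch.int32)

# Optional custom mask: one flattened [extend_len, prefix_len + extend_len]
# boolean block per request; mask_offsets[i] is request i's start in the buffer.
seq_lens = prefix_lens + extend_lens
block_sizes = (extend_lens * seq_lens).to(torch.int64)
mask_offsets = torch.cumsum(block_sizes, dim=0) - block_sizes   # exclusive cumsum
custom_mask = torch.ones(int(block_sizes.sum()), dtype=torch.bool)

print(qo_indptr.tolist(), kv_indptr.tolist(), mask_offsets.tolist())
```

Each request's mask block is laid out row-major as `[extend_len, prefix_len + extend_len]`, matching the kernel's `(row) * cur_seq_len + col` addressing; passing `custom_mask=None` keeps the plain causal path, since `USE_CUSTOM_MASK` is derived from `custom_mask is not None`.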
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip

 is_hip_flag = is_hip()
-
+

 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0

@@ -27,6 +27,15 @@ enable_moe_align_block_size_triton = bool(
     int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
 )

+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_rocm = torch.cuda.is_available() and torch.version.hip
+
+if _is_cuda:
+    from sgl_kernel import gelu_and_mul, silu_and_mul
+
+if _is_cuda or _is_rocm:
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+

 @triton.jit
 def fused_moe_kernel(

@@ -417,12 +426,12 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
-        token_cnts_buffer = torch.empty(
+        token_cnts_buffer = torch.zeros(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
             device=topk_ids.device,
         )
-        cumsum_buffer = torch.empty(
+        cumsum_buffer = torch.zeros(
             num_experts + 1, dtype=torch.int32, device=topk_ids.device
         )

@@ -989,9 +998,15 @@ def fused_experts_impl(
         )

         if activation == "silu":
-            ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+            if _is_cuda:
+                silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+            else:
+                ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
         elif activation == "gelu":
-            ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+            if _is_cuda:
+                gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+            else:
+                ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
         else:
             raise ValueError(f"Unsupported activation: {activation=}")

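On CUDA the activation step now calls `silu_and_mul` / `gelu_and_mul` from `sgl_kernel`; note that, as shown in the hunk above, the argument order differs from the `ops.*` fallback (the sgl_kernel call passes the input first and the output buffer second, the fallback the reverse). As a reference for what the fused op computes, here is a plain-PyTorch sketch; it is not the sgl_kernel implementation.

```python
import torch
import torch.nn.functional as F

# Reference semantics sketch: the input has width N (as in the hunk above);
# the first N // 2 columns are passed through SiLU and multiplied elementwise
# with the remaining N // 2 columns. gelu_and_mul is the same with F.gelu.
def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up

x = torch.randn(4, 16)
print(silu_and_mul_ref(x).shape)   # torch.Size([4, 8])
```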