sglang 0.4.2.post2__py3-none-any.whl → 0.4.2.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/function_call_parser.py +96 -69
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
- sglang/srt/layers/attention/triton_backend.py +64 -16
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -2
- sglang/srt/layers/quantization/fp8_kernel.py +43 -10
- sglang/srt/models/llama.py +8 -3
- sglang/srt/speculative/build_eagle_tree.py +482 -102
- sglang/srt/speculative/eagle_utils.py +80 -50
- sglang/version.py +1 -1
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post3.dist-info}/METADATA +2 -2
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post3.dist-info}/RECORD +16 -16
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py
@@ -3,6 +3,13 @@ import triton
 import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.utils import is_hip
+
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
+
+is_hip_ = is_hip()
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
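The `if is_cuda_available:` guard added here is what keeps the module importable on ROCm or CPU-only builds, where `torch.cuda.get_device_capability()` would raise; later code only reads `CUDA_CAPABILITY` behind the same flag. A minimal sketch of the pattern (the `hopper_or_newer` name is just for illustration, not part of the diff):

    import torch

    is_cuda_available = torch.cuda.is_available()
    if is_cuda_available:
        CUDA_CAPABILITY = torch.cuda.get_device_capability()  # e.g. (9, 0) on H100

    # CUDA_CAPABILITY may be undefined on ROCm/CPU builds, so checks keep the
    # short-circuit order used later in this file.
    hopper_or_newer = is_cuda_available and CUDA_CAPABILITY[0] >= 9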
@@ -274,9 +281,6 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq):
     return
 
 
-import torch
-
-
 def flash_decode_attention_fwd(
     q,
     k_buffer,
@@ -770,3 +774,333 @@ def flash_decode_sparse_attention_fwd(
     )
 
     sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ)
+
+
+# Extend attention kernel for Double Sparsity
+# Moved from https://github.com/sgl-project/sglang/blob/v0.4.2.post1/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+@triton.jit
+def _fwd_kernel(
+    Q_Extend,
+    K_Extend,
+    V_Extend,
+    O_Extend,
+    K_Buffer,
+    V_Buffer,
+    Req_to_tokens,
+    B_req_idx,
+    B_Seq_Len,
+    B_Start_Loc_Extend,
+    B_Seq_Len_Extend,
+    sm_scale,
+    kv_group_num,
+    stride_qbs,
+    stride_qh,
+    stride_kbs,
+    stride_kh,
+    stride_vbs,
+    stride_vh,
+    stride_obs,
+    stride_oh,
+    stride_buf_kbs,
+    stride_buf_kh,
+    stride_buf_vbs,
+    stride_buf_vh,
+    stride_req_to_tokens_b,
+    logit_cap: tl.constexpr,
+    Lq: tl.constexpr,
+    Lv: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    cur_seq = tl.program_id(0)
+    cur_head = tl.program_id(1)
+    cur_block_m = tl.program_id(2)
+    cur_kv_head = cur_head // kv_group_num
+
+    cur_seq_len = tl.load(B_Seq_Len + cur_seq)
+    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
+    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
+
+    cur_seq_prefix_start_in_loc = 0
+    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
+
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
+    offs_m = tl.arange(0, BLOCK_M)
+    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
+
+    mask_d = offs_d < Lq
+    mask_dv = offs_dv < Lv
+
+    offs_q = (
+        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        * stride_qbs
+        + cur_head * stride_qh
+        + offs_d[None, :]
+    )
+    q = tl.load(
+        Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
+    )
+
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        offs_qpe = (
+            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            * stride_qbs
+            + cur_head * stride_qh
+            + offs_dpe[None, :]
+        )
+        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
+
+    # stage 1: compute scores with prefix
+    offs_n = tl.arange(0, BLOCK_N)
+
+    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
+    deno = tl.zeros([BLOCK_M], dtype=tl.float32)
+    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+
+    for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        mask_n = (start_n + offs_n) < cur_seq_len_prefix
+        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
+            cur_seq_prefix_start_in_loc + start_n + offs_n
+        )
+        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
+
+        # load k in transposed way
+        offs_buf_k = (
+            offs_kv_loc[None, :] * stride_buf_kbs
+            + cur_kv_head * stride_buf_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
+
+        qk = tl.dot(q.to(k.dtype), k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe.to(kpe.dtype), kpe)
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        re_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        deno = deno * re_scale + tl.sum(p, 1)
+
+        offs_buf_v = (
+            offs_kv_loc[:, None] * stride_buf_vbs
+            + cur_kv_head * stride_buf_vh
+            + offs_dv[None, :]
+        )
+        v = tl.load(
+            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
+        p = p.to(v.dtype)
+        acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+        e_max = n_e_max
+
+    # stage 2: compute the trianlge part
+
+    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    for start_n in range(0, cur_block_m_end, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        mask_n = (start_n + offs_n) < cur_block_m_end
+
+        # load k in transposed way
+        offs_k = (
+            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
+            + cur_kv_head * stride_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
+
+        qk = tl.dot(q, k, out_dtype=tl.float32)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
+                * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Extend + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
+
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
+            start_n + offs_n[None, :]
+        )
+        mask_causual &= mask_m[:, None] & mask_n[None, :]
+        qk = tl.where(mask_causual, qk, float("-inf"))
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        re_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        deno = deno * re_scale + tl.sum(p, 1)
+
+        offs_v = (
+            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+            + cur_kv_head * stride_vh
+            + offs_dv[None, :]
+        )
+        v = tl.load(
+            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
+        p = p.to(v.dtype)
+        acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+        e_max = n_e_max
+
+    offs_o = (
+        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        * stride_obs
+        + cur_head * stride_oh
+        + offs_dv[None, :]
+    )
+    tl.store(
+        O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
+    )
+
+
+def extend_attention_fwd(
+    q_extend,
+    k_extend,
+    v_extend,
+    o_extend,
+    k_buffer,
+    v_buffer,
+    req_to_tokens,
+    b_req_idx,
+    b_seq_len,
+    b_seq_len_extend,
+    b_start_loc_extend,
+    max_len_extend,
+    sm_scale=None,
+    logit_cap=0.0,
+):
+    """
+    q_extend, k_extend, v_extend, o_extend: contiguous tensors
+
+    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
+    """
+    Lq, Lk, Lv = (
+        q_extend.shape[-1],
+        k_extend.shape[-1],
+        v_extend.shape[-1],
+    )
+
+    if Lq == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    elif Lq == 288:
+        BLOCK_DMODEL = 256
+        BLOCK_DPE = 32
+    elif Lq == 192:
+        BLOCK_DMODEL = 128
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = triton.next_power_of_2(Lq)
+        BLOCK_DPE = 0
+    BLOCK_DV = triton.next_power_of_2(Lv)
+
+    if is_hip_:
+        BLOCK_M, BLOCK_N = (64, 64)
+        num_warps = 4
+
+    else:
+        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+            if Lq <= 256:
+                BLOCK_M, BLOCK_N = (128, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (128, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        else:
+            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+
+        num_warps = 4 if Lk <= 64 else 8
+
+    sm_scale = sm_scale or 1.0 / (Lq**0.5)
+    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
+    kv_group_num = q_extend.shape[1] // k_extend.shape[1]
+
+    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
+    num_stages = 1
+
+    extra_kargs = {}
+    if is_hip_:
+        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+
+    _fwd_kernel[grid](
+        q_extend,
+        k_extend,
+        v_extend,
+        o_extend,
+        k_buffer,
+        v_buffer,
+        req_to_tokens,
+        b_req_idx,
+        b_seq_len,
+        b_start_loc_extend,
+        b_seq_len_extend,
+        sm_scale,
+        kv_group_num,
+        q_extend.stride(0),
+        q_extend.stride(1),
+        k_extend.stride(0),
+        k_extend.stride(1),
+        v_extend.stride(0),
+        v_extend.stride(1),
+        o_extend.stride(0),
+        o_extend.stride(1),
+        k_buffer.stride(0),
+        k_buffer.stride(1),
+        v_buffer.stride(0),
+        v_buffer.stride(1),
+        req_to_tokens.stride(0),
+        logit_cap=logit_cap,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DV=BLOCK_DV,
+        BLOCK_M=BLOCK_M,
+        BLOCK_N=BLOCK_N,
+        Lq=Lq,
+        Lv=Lv,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        **extra_kargs,
+    )
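For orientation, a hedged sketch of how the wrapper copied in above might be called; the tensor shapes, dtypes, and toy sizes below are assumptions for illustration only (packed `*_extend` tensors over all new tokens, and a paged `k_buffer`/`v_buffer` pool indexed through `req_to_tokens`):

    import torch

    # Two requests, 8 new (extend) tokens each, on top of cached prefixes.
    num_heads, head_dim = 8, 128
    q_extend = torch.randn(16, num_heads, head_dim, dtype=torch.float16, device="cuda")
    k_extend = torch.randn(16, num_heads, head_dim, dtype=torch.float16, device="cuda")
    v_extend = torch.randn_like(k_extend)
    o_extend = torch.empty_like(q_extend)

    # KV pool plus the request -> pool-slot mapping (sizes are made up).
    k_buffer = torch.randn(4096, num_heads, head_dim, dtype=torch.float16, device="cuda")
    v_buffer = torch.randn_like(k_buffer)
    req_to_tokens = torch.arange(2 * 64, dtype=torch.int32, device="cuda").reshape(2, 64)

    b_req_idx = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
    b_seq_len = torch.tensor([40, 24], dtype=torch.int32, device="cuda")         # prefix + extend
    b_seq_len_extend = torch.tensor([8, 8], dtype=torch.int32, device="cuda")
    b_start_loc_extend = torch.tensor([0, 8], dtype=torch.int32, device="cuda")  # offsets into *_extend

    extend_attention_fwd(
        q_extend, k_extend, v_extend, o_extend,
        k_buffer, v_buffer,
        req_to_tokens, b_req_idx, b_seq_len, b_seq_len_extend, b_start_loc_extend,
        max_len_extend=8,
    )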
sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -46,11 +46,11 @@ def _fwd_kernel(
     O_Extend,
     K_Buffer,
     V_Buffer,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seq_Len,
-    B_Start_Loc_Extend,
-    B_Seq_Len_Extend,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
+    mask_ptr,
+    mask_offsets,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -65,7 +65,6 @@ def _fwd_kernel(
     stride_buf_kh,
     stride_buf_vbs,
     stride_buf_vh,
-    stride_req_to_tokens_b,
     logit_cap: tl.constexpr,
     Lq: tl.constexpr,
     Lv: tl.constexpr,
@@ -74,19 +73,21 @@ def _fwd_kernel(
     BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    USE_CUSTOM_MASK: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
     cur_block_m = tl.program_id(2)
     cur_kv_head = cur_head // kv_group_num
 
-    cur_seq_len = tl.load(B_Seq_Len + cur_seq)
-    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
-    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
+    cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq)
+    cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
+    cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
+    cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
+    cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend
 
-    cur_seq_prefix_start_in_loc = 0
-    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
+    if USE_CUSTOM_MASK:
+        cur_seq_mask_start_idx = tl.load(mask_offsets + cur_seq)
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
@@ -97,7 +98,7 @@ def _fwd_kernel(
     mask_dv = offs_dv < Lv
 
     offs_q = (
-        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
         + cur_head * stride_qh
         + offs_d[None, :]
@@ -109,7 +110,7 @@ def _fwd_kernel(
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
         offs_qpe = (
-            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
             * stride_qbs
             + cur_head * stride_qh
             + offs_dpe[None, :]
@@ -126,10 +127,9 @@ def _fwd_kernel(
     for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix
-        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
-            cur_seq_prefix_start_in_loc + start_n + offs_n
+        offs_kv_loc = tl.load(
+            kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
         )
-        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
 
         # load k in transposed way
         offs_buf_k = (
@@ -159,7 +159,20 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
-        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
+        if USE_CUSTOM_MASK:
+            custom_mask = tl.load(
+                mask_ptr
+                + cur_seq_mask_start_idx
+                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + start_n
+                + offs_n[None, :],
+                mask=(mask_m[:, None] & mask_n[None, :]),
+                other=0,
+            )
+            custom_mask &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(custom_mask, qk, float("-inf"))
+        else:
+            qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
 
         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
@@ -179,7 +192,7 @@ def _fwd_kernel(
 
         e_max = n_e_max
 
-    # stage 2: compute the trianlge part
+    # stage 2: compute the triangle part
 
     cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
     for start_n in range(0, cur_block_m_end, BLOCK_N):
@@ -188,7 +201,7 @@ def _fwd_kernel(
 
         # load k in transposed way
         offs_k = (
-            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
+            (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
             + cur_kv_head * stride_kh
             + offs_d[:, None]
         )
@@ -199,8 +212,7 @@ def _fwd_kernel(
         qk = tl.dot(q, k, out_dtype=tl.float32)
         if BLOCK_DPE > 0:
             offs_kpe = (
-                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
-                * stride_kbs
+                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
                 + cur_kv_head * stride_kh
                 + offs_dpe[:, None]
             )
@@ -216,11 +228,25 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
-        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
-            start_n + offs_n[None, :]
-        )
-        mask_causual &= mask_m[:, None] & mask_n[None, :]
-        qk = tl.where(mask_causual, qk, float("-inf"))
+        if USE_CUSTOM_MASK:
+            custom_mask = tl.load(
+                mask_ptr
+                + cur_seq_mask_start_idx
+                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + cur_seq_len_prefix
+                + start_n
+                + offs_n[None, :],
+                mask=(mask_m[:, None] & mask_n[None, :]),
+                other=0,
+            )
+            custom_mask &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(custom_mask, qk, float("-inf"))
+        else:
+            mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
+                start_n + offs_n[None, :]
+            )
+            mask_causual &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_causual, qk, float("-inf"))
 
         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
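The pointer arithmetic in the two custom-mask branches implies a per-request layout: a flattened boolean matrix of shape [extend_len, seq_len] whose rows are the new (extend) tokens and whose columns run over the prefix tokens first and then the extend tokens, with `mask_offsets[i]` pointing at where request i starts inside the flat buffer. A hedged host-side sketch of building masks in that layout (the helper name and the causal example are assumptions, not sglang API):

    import torch

    def build_custom_masks(prefix_lens, extend_lens, allow):
        # Flatten per-request [extend_len, seq_len] masks the way the kernel
        # indexes them: row * seq_len + col, where columns < prefix_len cover the
        # cached prefix and the remaining columns cover the extend tokens.
        flat, offsets, cursor = [], [], 0
        for i, (p, e) in enumerate(zip(prefix_lens, extend_lens)):
            seq_len = p + e
            m = torch.tensor(
                [[allow(i, row, col) for col in range(seq_len)] for row in range(e)],
                dtype=torch.bool,
            )
            offsets.append(cursor)
            flat.append(m.flatten())
            cursor += e * seq_len
        return torch.cat(flat), torch.tensor(offsets, dtype=torch.int64)

    # Example: ordinary causal attention expressed as a custom mask.
    prefix_lens, extend_lens = [32, 10], [4, 6]
    custom_mask, mask_offsets = build_custom_masks(
        prefix_lens, extend_lens, lambda i, row, col: col <= prefix_lens[i] + row
    )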
@@ -228,7 +254,7 @@ def _fwd_kernel(
         deno = deno * re_scale + tl.sum(p, 1)
 
         offs_v = (
-            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+            (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
             + cur_kv_head * stride_vh
             + offs_dv[None, :]
         )
@@ -241,7 +267,7 @@ def _fwd_kernel(
         e_max = n_e_max
 
     offs_o = (
-        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_obs
         + cur_head * stride_oh
         + offs_dv[None, :]
@@ -258,11 +284,11 @@ def extend_attention_fwd(
     o_extend,
     k_buffer,
     v_buffer,
-    req_to_tokens,
-    b_req_idx,
-    b_seq_len,
-    b_seq_len_extend,
-    b_start_loc_extend,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
+    custom_mask,
+    mask_offsets,
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,
@@ -315,15 +341,17 @@ def extend_attention_fwd(
         num_warps = 4 if Lk <= 64 else 8
 
     sm_scale = sm_scale or 1.0 / (Lq**0.5)
-    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
+    batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]
 
+    USE_CUSTOM_MASK = custom_mask is not None
+
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
     num_stages = 1
 
    extra_kargs = {}
    if is_hip_:
-        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
 
     _fwd_kernel[grid](
         q_extend,
@@ -332,11 +360,11 @@ def extend_attention_fwd(
         o_extend,
         k_buffer,
         v_buffer,
-        req_to_tokens,
-        b_req_idx,
-        b_seq_len,
-        b_start_loc_extend,
-        b_seq_len_extend,
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        custom_mask,
+        mask_offsets,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
@@ -351,7 +379,6 @@ def extend_attention_fwd(
         k_buffer.stride(1),
         v_buffer.stride(0),
         v_buffer.stride(1),
-        req_to_tokens.stride(0),
         logit_cap=logit_cap,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
@@ -360,6 +387,7 @@ def extend_attention_fwd(
         BLOCK_N=BLOCK_N,
         Lq=Lq,
         Lv=Lv,
+        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
         num_warps=num_warps,
         num_stages=num_stages,
         **extra_kargs,
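Taken together, these hunks replace the `req_to_tokens`/`b_seq_len` interface of `extend_attention_fwd` with a CSR-style one: `qo_indptr` brackets each request's extend (query) tokens inside the packed `q_extend`, `kv_indptr`/`kv_indices` bracket and enumerate its prefix KV slots in the pool, and `custom_mask`/`mask_offsets` are optional (`USE_CUSTOM_MASK = custom_mask is not None`). A hedged sketch of building the index tensors on the host, with made-up lengths and sequential pool slots:

    import torch

    extend_lens = torch.tensor([8, 8], dtype=torch.int32)    # new tokens per request
    prefix_lens = torch.tensor([32, 16], dtype=torch.int32)  # cached tokens per request

    # indptr tensors have batch_size + 1 entries; [i] and [i + 1] bracket request i.
    qo_indptr = torch.zeros(len(extend_lens) + 1, dtype=torch.int32)
    qo_indptr[1:] = torch.cumsum(extend_lens, dim=0)
    kv_indptr = torch.zeros(len(prefix_lens) + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)

    # kv_indices lists the KV-pool slot of every prefix token, request by request;
    # sequential slots here are purely for illustration.
    kv_indices = torch.arange(int(prefix_lens.sum()), dtype=torch.int32)

    # Passing custom_mask=None keeps the causal path inside the kernel.
    custom_mask, mask_offsets = None, None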
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -417,12 +417,12 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
-        token_cnts_buffer = torch.
+        token_cnts_buffer = torch.zeros(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
             device=topk_ids.device,
         )
-        cumsum_buffer = torch.
+        cumsum_buffer = torch.zeros(
             num_experts + 1, dtype=torch.int32, device=topk_ids.device
         )
 
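The switch to zero-initialized buffers matters because `token_cnts_buffer` and `cumsum_buffer` are used as running counters/offsets by the alignment kernel, so any garbage left by an uninitialized allocation would be accumulated into the result (the pre-fix allocation call is truncated in this diff view). A tiny hedged illustration of the failure mode with a stand-in Python histogram, not the actual Triton kernel:

    import torch

    num_experts = 4
    topk_ids = torch.tensor([0, 2, 2, 3])

    uninitialized = torch.empty(num_experts, dtype=torch.int32)  # arbitrary contents
    zeroed = torch.zeros(num_experts, dtype=torch.int32)

    for e in topk_ids.tolist():  # stand-in for the per-expert counting the kernel does
        uninitialized[e] += 1
        zeroed[e] += 1

    # `zeroed` is the true histogram [1, 0, 2, 1]; `uninitialized` is off by
    # whatever values happened to be in the freshly allocated memory.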
sglang/srt/layers/quantization/fp8_kernel.py
@@ -279,12 +279,21 @@ def _w8a8_block_fp8_matmul_unrolledx4(
 
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     # manually unroll to 4 iterations
-
+    UNROLL_FACTOR = 4
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * UNROLL_FACTOR)):
         # 1st iteration
-        a = tl.load(
-
+        a = tl.load(
+            a_ptrs,
+            mask=offs_k[None, :] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K,
+            other=0.0,
+        )
+        b = tl.load(
+            b_ptrs,
+            mask=offs_k[:, None] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K,
+            other=0.0,
+        )
 
-        k_start = k * BLOCK_SIZE_K
+        k_start = (k * UNROLL_FACTOR) * BLOCK_SIZE_K
         offs_ks = k_start // group_k
         a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
         b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
@@ -294,8 +303,16 @@ def _w8a8_block_fp8_matmul_unrolledx4(
         b_ptrs += BLOCK_SIZE_K * stride_bk
 
         # 2nd iteration
-        a = tl.load(
-
+        a = tl.load(
+            a_ptrs,
+            mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K,
+            other=0.0,
+        )
+        b = tl.load(
+            b_ptrs,
+            mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K,
+            other=0.0,
+        )
 
         k_start = k_start + BLOCK_SIZE_K
         offs_ks = k_start // group_k
@@ -307,8 +324,16 @@ def _w8a8_block_fp8_matmul_unrolledx4(
         b_ptrs += BLOCK_SIZE_K * stride_bk
 
         # 3rd iteration
-        a = tl.load(
-
+        a = tl.load(
+            a_ptrs,
+            mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K,
+            other=0.0,
+        )
+        b = tl.load(
+            b_ptrs,
+            mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K,
+            other=0.0,
+        )
 
         k_start = k_start + BLOCK_SIZE_K
         offs_ks = k_start // group_k
@@ -320,8 +345,16 @@ def _w8a8_block_fp8_matmul_unrolledx4(
         b_ptrs += BLOCK_SIZE_K * stride_bk
 
         # 4th iteration
-        a = tl.load(
-
+        a = tl.load(
+            a_ptrs,
+            mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K,
+            other=0.0,
+        )
+        b = tl.load(
+            b_ptrs,
+            mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K,
+            other=0.0,
+        )
 
         k_start = k_start + BLOCK_SIZE_K
         offs_ks = k_start // group_k
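The masks added in these four iterations guard the K tail: since `k` now counts unroll groups, sub-iteration `j` of group `k` starts at `(k * UNROLL_FACTOR + j) * BLOCK_SIZE_K`, and an element `offs_k` is in range only while `offs_k < K - (k * UNROLL_FACTOR + j) * BLOCK_SIZE_K`. A hedged host-side check of that arithmetic for a K that is not a multiple of `4 * BLOCK_SIZE_K`:

    UNROLL_FACTOR, BLOCK_SIZE_K, K = 4, 128, 1000

    covered = []
    num_groups = -(-K // (BLOCK_SIZE_K * UNROLL_FACTOR))  # same as tl.cdiv
    for k in range(num_groups):
        for j in range(UNROLL_FACTOR):  # the four manually unrolled iterations
            base = (k * UNROLL_FACTOR + j) * BLOCK_SIZE_K
            for offs_k in range(BLOCK_SIZE_K):
                if offs_k < K - base:  # the mask condition added in this diff
                    covered.append(base + offs_k)

    assert sorted(covered) == list(range(K))  # every K element once, nothing past K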