sglang 0.4.2.post2__py3-none-any.whl → 0.4.2.post4__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (49)
  1. sglang/check_env.py +1 -0
  2. sglang/srt/constrained/outlines_backend.py +4 -1
  3. sglang/srt/function_call_parser.py +96 -69
  4. sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
  5. sglang/srt/layers/attention/flashinfer_backend.py +34 -41
  6. sglang/srt/layers/attention/triton_backend.py +64 -16
  7. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
  8. sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
  9. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -5
  10. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  11. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  12. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  13. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  15. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  19. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  20. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  21. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  22. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/quantization/fp8_kernel.py +43 -10
  24. sglang/srt/lora/backend/__init__.py +25 -5
  25. sglang/srt/lora/backend/base_backend.py +31 -9
  26. sglang/srt/lora/backend/flashinfer_backend.py +41 -4
  27. sglang/srt/lora/backend/triton_backend.py +34 -4
  28. sglang/srt/lora/layers.py +293 -0
  29. sglang/srt/lora/lora.py +101 -326
  30. sglang/srt/lora/lora_manager.py +101 -269
  31. sglang/srt/lora/mem_pool.py +174 -0
  32. sglang/srt/lora/triton_ops/__init__.py +7 -1
  33. sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
  34. sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
  35. sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
  36. sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
  37. sglang/srt/lora/utils.py +141 -0
  38. sglang/srt/model_executor/cuda_graph_runner.py +4 -0
  39. sglang/srt/models/llama.py +8 -3
  40. sglang/srt/speculative/build_eagle_tree.py +482 -102
  41. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  42. sglang/srt/speculative/eagle_utils.py +134 -61
  43. sglang/srt/speculative/eagle_worker.py +1 -0
  44. sglang/version.py +1 -1
  45. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/METADATA +4 -4
  46. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/RECORD +49 -32
  47. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/LICENSE +0 -0
  48. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/WHEEL +0 -0
  49. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py

@@ -3,6 +3,13 @@ import triton
 import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.utils import is_hip
+
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
+
+is_hip_ = is_hip()
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -274,9 +281,6 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq):
     return
 
 
-import torch
-
-
 def flash_decode_attention_fwd(
     q,
     k_buffer,
@@ -770,3 +774,333 @@ def flash_decode_sparse_attention_fwd(
     )
 
     sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ)
+
+
+# Extend attention kernel for Double Sparsity
+# Moved from https://github.com/sgl-project/sglang/blob/v0.4.2.post1/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+@triton.jit
+def _fwd_kernel(
+    Q_Extend,
+    K_Extend,
+    V_Extend,
+    O_Extend,
+    K_Buffer,
+    V_Buffer,
+    Req_to_tokens,
+    B_req_idx,
+    B_Seq_Len,
+    B_Start_Loc_Extend,
+    B_Seq_Len_Extend,
+    sm_scale,
+    kv_group_num,
+    stride_qbs,
+    stride_qh,
+    stride_kbs,
+    stride_kh,
+    stride_vbs,
+    stride_vh,
+    stride_obs,
+    stride_oh,
+    stride_buf_kbs,
+    stride_buf_kh,
+    stride_buf_vbs,
+    stride_buf_vh,
+    stride_req_to_tokens_b,
+    logit_cap: tl.constexpr,
+    Lq: tl.constexpr,
+    Lv: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    cur_seq = tl.program_id(0)
+    cur_head = tl.program_id(1)
+    cur_block_m = tl.program_id(2)
+    cur_kv_head = cur_head // kv_group_num
+
+    cur_seq_len = tl.load(B_Seq_Len + cur_seq)
+    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
+    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
+
+    cur_seq_prefix_start_in_loc = 0
+    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
+
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
+    offs_m = tl.arange(0, BLOCK_M)
+    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
+
+    mask_d = offs_d < Lq
+    mask_dv = offs_dv < Lv
+
+    offs_q = (
+        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        * stride_qbs
+        + cur_head * stride_qh
+        + offs_d[None, :]
+    )
+    q = tl.load(
+        Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
+    )
+
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        offs_qpe = (
+            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            * stride_qbs
+            + cur_head * stride_qh
+            + offs_dpe[None, :]
+        )
+        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
+
+    # stage 1: compute scores with prefix
+    offs_n = tl.arange(0, BLOCK_N)
+
+    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
+    deno = tl.zeros([BLOCK_M], dtype=tl.float32)
+    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+
+    for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        mask_n = (start_n + offs_n) < cur_seq_len_prefix
+        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
+            cur_seq_prefix_start_in_loc + start_n + offs_n
+        )
+        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
+
+        # load k in transposed way
+        offs_buf_k = (
+            offs_kv_loc[None, :] * stride_buf_kbs
+            + cur_kv_head * stride_buf_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
+
+        qk = tl.dot(q.to(k.dtype), k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe.to(kpe.dtype), kpe)
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        re_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        deno = deno * re_scale + tl.sum(p, 1)
+
+        offs_buf_v = (
+            offs_kv_loc[:, None] * stride_buf_vbs
+            + cur_kv_head * stride_buf_vh
+            + offs_dv[None, :]
+        )
+        v = tl.load(
+            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
+        p = p.to(v.dtype)
+        acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+        e_max = n_e_max
+
+    # stage 2: compute the trianlge part
+
+    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    for start_n in range(0, cur_block_m_end, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        mask_n = (start_n + offs_n) < cur_block_m_end
+
+        # load k in transposed way
+        offs_k = (
+            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
+            + cur_kv_head * stride_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
+
+        qk = tl.dot(q, k, out_dtype=tl.float32)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
+                * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Extend + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
+
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
+            start_n + offs_n[None, :]
+        )
+        mask_causual &= mask_m[:, None] & mask_n[None, :]
+        qk = tl.where(mask_causual, qk, float("-inf"))
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        re_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        deno = deno * re_scale + tl.sum(p, 1)
+
+        offs_v = (
+            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+            + cur_kv_head * stride_vh
+            + offs_dv[None, :]
+        )
+        v = tl.load(
+            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
+        p = p.to(v.dtype)
+        acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+        e_max = n_e_max
+
+    offs_o = (
+        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        * stride_obs
+        + cur_head * stride_oh
+        + offs_dv[None, :]
+    )
+    tl.store(
+        O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
+    )
+
+
+def extend_attention_fwd(
+    q_extend,
+    k_extend,
+    v_extend,
+    o_extend,
+    k_buffer,
+    v_buffer,
+    req_to_tokens,
+    b_req_idx,
+    b_seq_len,
+    b_seq_len_extend,
+    b_start_loc_extend,
+    max_len_extend,
+    sm_scale=None,
+    logit_cap=0.0,
+):
+    """
+    q_extend, k_extend, v_extend, o_extend: contiguous tensors
+
+    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
+    """
+    Lq, Lk, Lv = (
+        q_extend.shape[-1],
+        k_extend.shape[-1],
+        v_extend.shape[-1],
+    )
+
+    if Lq == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    elif Lq == 288:
+        BLOCK_DMODEL = 256
+        BLOCK_DPE = 32
+    elif Lq == 192:
+        BLOCK_DMODEL = 128
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = triton.next_power_of_2(Lq)
+        BLOCK_DPE = 0
+    BLOCK_DV = triton.next_power_of_2(Lv)
+
+    if is_hip_:
+        BLOCK_M, BLOCK_N = (64, 64)
+        num_warps = 4
+
+    else:
+        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+            if Lq <= 256:
+                BLOCK_M, BLOCK_N = (128, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (128, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        else:
+            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+
+        num_warps = 4 if Lk <= 64 else 8
+
+    sm_scale = sm_scale or 1.0 / (Lq**0.5)
+    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
+    kv_group_num = q_extend.shape[1] // k_extend.shape[1]
+
+    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
+    num_stages = 1
+
+    extra_kargs = {}
+    if is_hip_:
+        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+
+    _fwd_kernel[grid](
+        q_extend,
+        k_extend,
+        v_extend,
+        o_extend,
+        k_buffer,
+        v_buffer,
+        req_to_tokens,
+        b_req_idx,
+        b_seq_len,
+        b_start_loc_extend,
+        b_seq_len_extend,
+        sm_scale,
+        kv_group_num,
+        q_extend.stride(0),
+        q_extend.stride(1),
+        k_extend.stride(0),
+        k_extend.stride(1),
+        v_extend.stride(0),
+        v_extend.stride(1),
+        o_extend.stride(0),
+        o_extend.stride(1),
+        k_buffer.stride(0),
+        k_buffer.stride(1),
+        v_buffer.stride(0),
+        v_buffer.stride(1),
+        req_to_tokens.stride(0),
+        logit_cap=logit_cap,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DV=BLOCK_DV,
+        BLOCK_M=BLOCK_M,
+        BLOCK_N=BLOCK_N,
+        Lq=Lq,
+        Lv=Lv,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        **extra_kargs,
+    )
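
Note: the kernel added above accumulates attention one KV block at a time with a streaming (online) softmax: each block updates a running max e_max, denominator deno, and weighted sum acc, rescaling the previous partial results by exp(e_max - n_e_max). A minimal single-query NumPy sketch of that update rule, for orientation only; the names and shapes here are illustrative and not part of the package:

import numpy as np

def online_softmax_attention(q, k_blocks, v_blocks, sm_scale):
    # q: (d,), k_blocks[i]: (block_n, d), v_blocks[i]: (block_n, dv)
    acc = np.zeros(v_blocks[0].shape[-1])   # running weighted sum of V rows
    deno = 0.0                               # running softmax denominator
    e_max = -np.inf                          # running max logit
    for k, v in zip(k_blocks, v_blocks):
        qk = (k @ q) * sm_scale              # scores for this block
        n_e_max = max(qk.max(), e_max)
        re_scale = np.exp(e_max - n_e_max)   # rescale previous partial sums
        p = np.exp(qk - n_e_max)
        deno = deno * re_scale + p.sum()
        acc = acc * re_scale + p @ v
        e_max = n_e_max
    return acc / deno                        # equals softmax(q @ K.T * scale) @ V over all blocks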
sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -46,11 +46,11 @@ def _fwd_kernel(
     O_Extend,
     K_Buffer,
     V_Buffer,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seq_Len,
-    B_Start_Loc_Extend,
-    B_Seq_Len_Extend,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
+    mask_ptr,
+    mask_offsets,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -65,7 +65,6 @@ def _fwd_kernel(
     stride_buf_kh,
     stride_buf_vbs,
     stride_buf_vh,
-    stride_req_to_tokens_b,
     logit_cap: tl.constexpr,
     Lq: tl.constexpr,
     Lv: tl.constexpr,
@@ -74,19 +73,21 @@ def _fwd_kernel(
     BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    USE_CUSTOM_MASK: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
     cur_block_m = tl.program_id(2)
     cur_kv_head = cur_head // kv_group_num
 
-    cur_seq_len = tl.load(B_Seq_Len + cur_seq)
-    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
-    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
+    cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq)
+    cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
+    cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
+    cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
+    cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend
 
-    cur_seq_prefix_start_in_loc = 0
-    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
+    if USE_CUSTOM_MASK:
+        cur_seq_mask_start_idx = tl.load(mask_offsets + cur_seq)
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
@@ -97,7 +98,7 @@ def _fwd_kernel(
     mask_dv = offs_dv < Lv
 
     offs_q = (
-        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
         + cur_head * stride_qh
         + offs_d[None, :]
@@ -109,7 +110,7 @@ def _fwd_kernel(
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
         offs_qpe = (
-            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
             * stride_qbs
             + cur_head * stride_qh
             + offs_dpe[None, :]
@@ -126,10 +127,9 @@ def _fwd_kernel(
     for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix
-        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
-            cur_seq_prefix_start_in_loc + start_n + offs_n
+        offs_kv_loc = tl.load(
+            kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
         )
-        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
 
         # load k in transposed way
         offs_buf_k = (
@@ -159,7 +159,20 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
-        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
+        if USE_CUSTOM_MASK:
+            custom_mask = tl.load(
+                mask_ptr
+                + cur_seq_mask_start_idx
+                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + start_n
+                + offs_n[None, :],
+                mask=(mask_m[:, None] & mask_n[None, :]),
+                other=0,
+            )
+            custom_mask &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(custom_mask, qk, float("-inf"))
+        else:
+            qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
 
         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
@@ -179,7 +192,7 @@ def _fwd_kernel(
 
         e_max = n_e_max
 
-    # stage 2: compute the trianlge part
+    # stage 2: compute the triangle part
 
     cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
     for start_n in range(0, cur_block_m_end, BLOCK_N):
@@ -188,7 +201,7 @@ def _fwd_kernel(
 
         # load k in transposed way
         offs_k = (
-            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
+            (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
             + cur_kv_head * stride_kh
             + offs_d[:, None]
         )
@@ -199,8 +212,7 @@ def _fwd_kernel(
         qk = tl.dot(q, k, out_dtype=tl.float32)
         if BLOCK_DPE > 0:
             offs_kpe = (
-                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
-                * stride_kbs
+                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
                 + cur_kv_head * stride_kh
                 + offs_dpe[:, None]
             )
@@ -216,11 +228,25 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
-        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
-            start_n + offs_n[None, :]
-        )
-        mask_causual &= mask_m[:, None] & mask_n[None, :]
-        qk = tl.where(mask_causual, qk, float("-inf"))
+        if USE_CUSTOM_MASK:
+            custom_mask = tl.load(
+                mask_ptr
+                + cur_seq_mask_start_idx
+                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + cur_seq_len_prefix
+                + start_n
+                + offs_n[None, :],
+                mask=(mask_m[:, None] & mask_n[None, :]),
+                other=0,
+            )
+            custom_mask &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(custom_mask, qk, float("-inf"))
+        else:
+            mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
+                start_n + offs_n[None, :]
+            )
+            mask_causual &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_causual, qk, float("-inf"))
 
         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
@@ -228,7 +254,7 @@ def _fwd_kernel(
         deno = deno * re_scale + tl.sum(p, 1)
 
         offs_v = (
-            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+            (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
             + cur_kv_head * stride_vh
             + offs_dv[None, :]
         )
@@ -241,7 +267,7 @@ def _fwd_kernel(
         e_max = n_e_max
 
     offs_o = (
-        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_obs
         + cur_head * stride_oh
         + offs_dv[None, :]
@@ -258,11 +284,11 @@ def extend_attention_fwd(
     o_extend,
     k_buffer,
     v_buffer,
-    req_to_tokens,
-    b_req_idx,
-    b_seq_len,
-    b_seq_len_extend,
-    b_start_loc_extend,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
+    custom_mask,
+    mask_offsets,
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,
@@ -315,15 +341,17 @@ def extend_attention_fwd(
         num_warps = 4 if Lk <= 64 else 8
 
     sm_scale = sm_scale or 1.0 / (Lq**0.5)
-    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
+    batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
    kv_group_num = q_extend.shape[1] // k_extend.shape[1]
 
+    USE_CUSTOM_MASK = custom_mask is not None
+
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
     num_stages = 1
 
     extra_kargs = {}
     if is_hip_:
-        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
 
     _fwd_kernel[grid](
         q_extend,
@@ -332,11 +360,11 @@ def extend_attention_fwd(
         o_extend,
         k_buffer,
         v_buffer,
-        req_to_tokens,
-        b_req_idx,
-        b_seq_len,
-        b_start_loc_extend,
-        b_seq_len_extend,
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        custom_mask,
+        mask_offsets,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
@@ -351,7 +379,6 @@ def extend_attention_fwd(
         k_buffer.stride(1),
         v_buffer.stride(0),
         v_buffer.stride(1),
-        req_to_tokens.stride(0),
         logit_cap=logit_cap,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
@@ -360,6 +387,7 @@ def extend_attention_fwd(
         BLOCK_N=BLOCK_N,
         Lq=Lq,
         Lv=Lv,
+        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
         num_warps=num_warps,
         num_stages=num_stages,
         **extra_kargs,
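
Note: the interface change above replaces the req_to_tokens / b_req_idx / b_seq_len / b_start_loc_extend bookkeeping with CSR-style ragged indexing (qo_indptr, kv_indptr, kv_indices) plus an optional flattened custom_mask with per-request mask_offsets. A small sketch of how such inputs could be laid out for a toy batch, inferred from how the kernel indexes them; the values and helper names here are illustrative, not taken from the package:

import torch

# Toy batch: 2 requests, prefix lengths [3, 5] already in the KV cache,
# extend lengths [2, 4] new tokens being prefilled.
prefix_lens = torch.tensor([3, 5])
extend_lens = torch.tensor([2, 4])

# qo_indptr[i] .. qo_indptr[i + 1] delimits request i's rows in q_extend / o_extend.
qo_indptr = torch.zeros(len(extend_lens) + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_lens, dim=0)            # [0, 2, 6]

# kv_indptr[i] .. kv_indptr[i + 1] delimits request i's prefix slots in kv_indices.
kv_indptr = torch.zeros(len(prefix_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)            # [0, 3, 8]

# kv_indices lists the KV-cache slot of every prefix token, request by request.
kv_indices = torch.tensor([10, 11, 12, 40, 41, 42, 43, 44], dtype=torch.int32)

# Optional custom mask: one flattened (extend_len x seq_len) block per request;
# mask_offsets[i] is where request i's block starts in the flat mask buffer.
seq_lens = prefix_lens + extend_lens
mask_offsets = torch.zeros(len(seq_lens), dtype=torch.int64)
mask_offsets[1:] = torch.cumsum(extend_lens * seq_lens, dim=0)[:-1]   # [0, 10]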
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
 
 is_hip_flag = is_hip()
-from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -27,6 +27,15 @@ enable_moe_align_block_size_triton = bool(
     int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
 )
 
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_rocm = torch.cuda.is_available() and torch.version.hip
+
+if _is_cuda:
+    from sgl_kernel import gelu_and_mul, silu_and_mul
+
+if _is_cuda or _is_rocm:
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
 
 @triton.jit
 def fused_moe_kernel(
@@ -417,12 +426,12 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
-        token_cnts_buffer = torch.empty(
+        token_cnts_buffer = torch.zeros(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
             device=topk_ids.device,
         )
-        cumsum_buffer = torch.empty(
+        cumsum_buffer = torch.zeros(
             num_experts + 1, dtype=torch.int32, device=topk_ids.device
         )
 
@@ -989,9 +998,15 @@ def fused_experts_impl(
         )
 
         if activation == "silu":
-            ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+            if _is_cuda:
+                silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+            else:
+                ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
         elif activation == "gelu":
-            ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+            if _is_cuda:
+                gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+            else:
+                ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
         else:
             raise ValueError(f"Unsupported activation: {activation=}")
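
Note: in the activation dispatch above only the entry point changes: the sgl_kernel silu_and_mul / gelu_and_mul calls take (input, output) while the vllm ops variants take (output, input); both compute the same gated activation over a tensor whose last dimension holds the gate and up halves. A plain PyTorch reference of that semantics, given as an illustrative sketch rather than the package's implementation:

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x[..., :d] is the gate half, x[..., d:] is the up half (d = x.shape[-1] // 2).
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 8)
out = silu_and_mul_ref(x)   # shape (4, 4), matching intermediate_cache2's last dim of N // 2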