sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/function_call_parser.py +96 -69
  5. sglang/srt/layers/activation.py +10 -5
  6. sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
  7. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  8. sglang/srt/layers/attention/triton_backend.py +124 -12
  9. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
  12. sglang/srt/layers/layernorm.py +1 -5
  13. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -13
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  24. sglang/srt/layers/moe/topk.py +4 -0
  25. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  46. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/fp8_kernel.py +173 -2
  48. sglang/srt/layers/rotary_embedding.py +1 -3
  49. sglang/srt/layers/sampler.py +4 -4
  50. sglang/srt/lora/backend/__init__.py +8 -0
  51. sglang/srt/lora/backend/base_backend.py +95 -0
  52. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  53. sglang/srt/lora/backend/triton_backend.py +61 -0
  54. sglang/srt/lora/lora.py +127 -112
  55. sglang/srt/lora/lora_manager.py +50 -18
  56. sglang/srt/lora/triton_ops/__init__.py +5 -0
  57. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  58. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  59. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  60. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  61. sglang/srt/model_executor/forward_batch_info.py +58 -59
  62. sglang/srt/model_executor/model_runner.py +2 -2
  63. sglang/srt/models/llama.py +8 -3
  64. sglang/srt/models/qwen2_vl.py +1 -1
  65. sglang/srt/server_args.py +13 -2
  66. sglang/srt/speculative/build_eagle_tree.py +486 -104
  67. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  68. sglang/srt/speculative/eagle_utils.py +420 -401
  69. sglang/srt/speculative/eagle_worker.py +177 -45
  70. sglang/srt/utils.py +7 -0
  71. sglang/test/runners.py +2 -0
  72. sglang/version.py +1 -1
  73. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/METADATA +15 -6
  74. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/RECORD +77 -38
  75. sglang/srt/layers/custom_op_util.py +0 -25
  76. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/LICENSE +0 -0
  77. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/WHEEL +0 -0
  78. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py

@@ -10,6 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
  import os
  from dataclasses import dataclass
  from enum import Enum, auto
+ from functools import partial
  from typing import TYPE_CHECKING, List, Optional, Union

  import torch
@@ -34,6 +35,7 @@ if is_flashinfer_available():
          BatchPrefillWithRaggedKVCacheWrapper,
      )
      from flashinfer.cascade import merge_state
+     from flashinfer.decode import PosEncodingMode


  class WrapperDispatch(Enum):
@@ -53,10 +55,19 @@ class PrefillMetadata:
      extend_no_prefix: bool


+ # Reuse this workspace buffer across all flashinfer wrappers
+ global_workspace_buffer = None
+
+
  class FlashInferAttnBackend(AttentionBackend):
      """Flashinfer attention kernels."""

-     def __init__(self, model_runner: ModelRunner):
+     def __init__(
+         self,
+         model_runner: ModelRunner,
+         skip_prefill: bool = False,
+         kv_indptr_buf: Optional[torch.Tensor] = None,
+     ):
          super().__init__()

          # Parse constants
@@ -69,6 +80,7 @@ class FlashInferAttnBackend(AttentionBackend):
              ),
          )
          self.max_context_len = model_runner.model_config.context_len
+         self.skip_prefill = skip_prefill

          assert not (
              model_runner.sliding_window_size is not None
@@ -90,16 +102,26 @@ class FlashInferAttnBackend(AttentionBackend):
              global_config.flashinfer_workspace_size = 512 * 1024 * 1024

          # Allocate buffers
-         self.workspace_buffer = torch.empty(
-             global_config.flashinfer_workspace_size,
-             dtype=torch.uint8,
-             device=model_runner.device,
-         )
+         global global_workspace_buffer
+         if global_workspace_buffer is None:
+             global_workspace_buffer = torch.empty(
+                 global_config.flashinfer_workspace_size,
+                 dtype=torch.uint8,
+                 device=model_runner.device,
+             )
+         self.workspace_buffer = global_workspace_buffer
          max_bs = model_runner.req_to_token_pool.size
-         self.kv_indptr = [
-             torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
-             for _ in range(self.num_wrappers)
-         ]
+         if kv_indptr_buf is None:
+             self.kv_indptr = [
+                 torch.zeros(
+                     (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                 )
+                 for _ in range(self.num_wrappers)
+             ]
+         else:
+             assert self.num_wrappers == 1
+             self.kv_indptr = [kv_indptr_buf]
+
          self.kv_last_page_len = torch.ones(
              (max_bs,), dtype=torch.int32, device=model_runner.device
          )
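Note: the hunk above replaces the per-instance workspace allocation with a single module-level buffer that is created lazily and then shared by every FlashInferAttnBackend instance (including the per-step draft backends introduced later in this diff), so N backends no longer cost N separate 512 MB allocations. A minimal standalone sketch of that lazy-singleton pattern, using hypothetical names and an assumed fixed size, not the package's actual API:

    import torch

    _WORKSPACE_BYTES = 512 * 1024 * 1024  # assumed size; the real value comes from global_config
    _global_workspace = None  # module-level singleton shared by all backends


    def get_workspace(device: str = "cuda") -> torch.Tensor:
        """Lazily allocate one uint8 workspace buffer and reuse it afterwards."""
        global _global_workspace
        if _global_workspace is None:
            _global_workspace = torch.empty(
                _WORKSPACE_BYTES, dtype=torch.uint8, device=device
            )
        return _global_workspace


    class Backend:
        def __init__(self):
            # Every instance points at the same underlying allocation.
            self.workspace_buffer = get_workspace()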
@@ -122,12 +144,17 @@ class FlashInferAttnBackend(AttentionBackend):
          self.prefill_wrappers_verify = []
          self.decode_wrappers = []
          for _ in range(self.num_wrappers):
-             self.prefill_wrappers_paged.append(
-                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-             )
-             self.prefill_wrappers_verify.append(
-                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-             )
+             if not skip_prefill:
+                 self.prefill_wrappers_paged.append(
+                     BatchPrefillWithPagedKVCacheWrapper(
+                         self.workspace_buffer,
+                         "NHD",
+                         backend="fa2",
+                     )
+                 )
+                 self.prefill_wrappers_verify.append(
+                     BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+                 )
              self.decode_wrappers.append(
                  BatchDecodeWithPagedKVCacheWrapper(
                      self.workspace_buffer,
@@ -137,10 +164,11 @@ class FlashInferAttnBackend(AttentionBackend):
              )

          # Create indices updater
+         if not skip_prefill:
+             self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
+                 model_runner, self
+             )
          self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
-         self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
-             model_runner, self
-         )

          # Other metadata
          self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
@@ -211,23 +239,30 @@ class FlashInferAttnBackend(AttentionBackend):
                  self.prefill_wrappers_paged, use_ragged, extend_no_prefix
              )

-     def init_cuda_graph_state(self, max_bs: int):
-         cuda_graph_kv_indices = torch.zeros(
-             (max_bs * self.max_context_len,),
-             dtype=torch.int32,
-             device="cuda",
-         )
+     def init_cuda_graph_state(
+         self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+     ):
+         if kv_indices_buf is None:
+             cuda_graph_kv_indices = torch.zeros(
+                 (max_bs * self.max_context_len,),
+                 dtype=torch.int32,
+                 device="cuda",
+             )
+         else:
+             cuda_graph_kv_indices = kv_indices_buf
+
          self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
              cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
          ]

-         self.cuda_graph_custom_mask = torch.zeros(
-             (max_bs * self.max_context_len),
-             dtype=torch.uint8,
-             device="cuda",
-         )
-         self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
-         self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
+         if not self.skip_prefill:
+             self.cuda_graph_custom_mask = torch.zeros(
+                 (max_bs * self.max_context_len),
+                 dtype=torch.uint8,
+                 device="cuda",
+             )
+             self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
+             self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]

      def init_forward_metadata_capture_cuda_graph(
          self,
@@ -279,7 +314,7 @@ class FlashInferAttnBackend(AttentionBackend):
                          paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
                          paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
                          custom_mask_buf=self.cuda_graph_custom_mask,
-                         qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
+                         mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
                      )
                  )
              seq_lens_sum = seq_lens.sum().item()
@@ -602,11 +637,8 @@ class FlashInferIndicesUpdaterDecode:
                  self.req_to_token.shape[1],
              )
          else:
-             bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
-                 req_pool_indices,
-                 paged_kernel_lens,
-                 self.req_to_token,
-             )
+             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+             bs = kv_indptr.shape[0] - 1

          wrapper.end_forward()
          wrapper.begin_forward(
@@ -800,7 +832,9 @@ class FlashInferIndicesUpdaterPrefill:
          kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
          kv_indptr = kv_indptr[: bs + 1]
          kv_indices = torch.empty(
-             paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+             paged_kernel_lens_sum + 256,
+             dtype=torch.int32,
+             device=req_pool_indices.device,
          )
          create_flashinfer_kv_indices_triton[(bs,)](
              self.req_to_token,
@@ -852,6 +886,132 @@
          )


+ class FlashInferMultiStepDraftBackend:
+     """
+     Wrap multiple flashinfer attention backends as one for multiple consecutive
+     draft decoding steps.
+     """
+
+     def __init__(
+         self,
+         model_runner: ModelRunner,
+         topk: int,
+         speculative_num_steps: int,
+     ):
+         from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+         self.topk = topk
+         self.speculative_num_steps = speculative_num_steps
+         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+         max_bs = model_runner.req_to_token_pool.size
+         self.kv_indptr = torch.zeros(
+             (
+                 self.speculative_num_steps,
+                 max_bs + 1,
+             ),
+             dtype=torch.int32,
+             device=model_runner.device,
+         )
+         self.attn_backends = []
+         for i in range(self.speculative_num_steps):
+             self.attn_backends.append(
+                 FlashInferAttnBackend(
+                     model_runner,
+                     skip_prefill=True,
+                     kv_indptr_buf=self.kv_indptr[i],
+                 )
+             )
+         self.max_context_len = self.attn_backends[0].max_context_len
+         # Cached variables for generate_draft_decode_kv_indices
+         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+         self.kv_indptr_stride = self.kv_indptr.shape[1]
+
+     def common_template(self, forward_batch: ForwardBatch, call_fn: int):
+         num_seqs = forward_batch.batch_size
+         bs = self.topk * num_seqs
+         seq_lens_sum = forward_batch.seq_lens_sum
+         self.generate_draft_decode_kv_indices[
+             (self.speculative_num_steps, num_seqs, self.topk)
+         ](
+             forward_batch.req_pool_indices,
+             forward_batch.req_to_token_pool.req_to_token,
+             forward_batch.seq_lens,
+             self.cuda_graph_kv_indices,
+             self.kv_indptr,
+             forward_batch.positions,
+             num_seqs,
+             self.topk,
+             self.pool_len,
+             self.kv_indptr_stride,
+             self.kv_indptr.shape[1],
+             triton.next_power_of_2(num_seqs),
+             triton.next_power_of_2(self.speculative_num_steps),
+             triton.next_power_of_2(bs),
+         )
+         for i in range(self.speculative_num_steps):
+             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+             forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
+                 : seq_lens_sum * self.topk + bs * (i + 1)
+             ]
+             call_fn(i, forward_batch)
+
+     def init_forward_metadata(self, forward_batch: ForwardBatch):
+         def call_fn(i, forward_batch):
+             forward_batch.spec_info.kv_indptr = (
+                 forward_batch.spec_info.kv_indptr.clone()
+             )
+             forward_batch.spec_info.kv_indices = (
+                 forward_batch.spec_info.kv_indices.clone()
+             )
+             self.attn_backends[i].init_forward_metadata(forward_batch)
+
+         self.common_template(forward_batch, call_fn)
+
+     def init_cuda_graph_state(self, max_bs: int):
+         self.cuda_graph_kv_indices = torch.zeros(
+             (self.speculative_num_steps, max_bs * self.max_context_len),
+             dtype=torch.int32,
+             device="cuda",
+         )
+         self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
+         for i in range(self.speculative_num_steps):
+             self.attn_backends[i].init_cuda_graph_state(
+                 max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+             )
+
+     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+         def call_fn(i, forward_batch):
+             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                 forward_batch.batch_size,
+                 forward_batch.batch_size * self.topk,
+                 forward_batch.req_pool_indices,
+                 forward_batch.seq_lens,
+                 encoder_lens=None,
+                 forward_mode=ForwardMode.DECODE,
+                 spec_info=forward_batch.spec_info,
+             )
+             decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
+                 forward_batch.batch_size
+             ][0]
+             decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
+
+         self.common_template(forward_batch, call_fn)
+
+     def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+         def call_fn(i, forward_batch):
+             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                 forward_batch.batch_size,
+                 forward_batch.req_pool_indices,
+                 forward_batch.seq_lens,
+                 seq_lens_sum=-1,
+                 encoder_lens=None,
+                 forward_mode=ForwardMode.DECODE,
+                 spec_info=forward_batch.spec_info,
+             )
+
+         self.common_template(forward_batch, call_fn)
+
+
  @triton.jit
  def create_flashinfer_kv_indices_triton(
      req_to_token_ptr,  # [max_batch, max_context_len]
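Note: the new FlashInferMultiStepDraftBackend above allocates one stacked kv_indptr tensor of shape (speculative_num_steps, max_bs + 1) and hands each per-step FlashInferAttnBackend its own row as kv_indptr_buf, so a single Triton launch in common_template can fill the indptr for every draft step while each step's backend keeps reading its preallocated view. A small CPU-only sketch of that row-view sharing, with made-up sizes (not part of the package):

    import torch

    NUM_STEPS = 4  # hypothetical speculative_num_steps
    MAX_BS = 8     # hypothetical max batch size

    # One stacked indptr tensor; each draft step gets its own row as a buffer.
    kv_indptr = torch.zeros((NUM_STEPS, MAX_BS + 1), dtype=torch.int32)
    per_step_indptr_bufs = [kv_indptr[i] for i in range(NUM_STEPS)]

    # Row indexing returns a view, so writes through the parent tensor are
    # immediately visible to the per-step buffers (and vice versa).
    per_step_indptr_bufs[2][1] = 7
    assert kv_indptr[2, 1] == 7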
@@ -935,3 +1095,88 @@ def should_use_tensor_core(
          return gqa_group_size > 4
      else:
          return False
+
+
+ def fast_decode_plan(
+     self,
+     indptr: torch.Tensor,
+     indices: torch.Tensor,
+     last_page_len: torch.Tensor,
+     num_qo_heads: int,
+     num_kv_heads: int,
+     head_dim: int,
+     page_size: int,
+     pos_encoding_mode: str = "NONE",
+     window_left: int = -1,
+     logits_soft_cap: Optional[float] = None,
+     data_type: Union[str, torch.dtype] = "float16",
+     q_data_type: Optional[Union[str, torch.dtype]] = None,
+     sm_scale: Optional[float] = None,
+     rope_scale: Optional[float] = None,
+     rope_theta: Optional[float] = None,
+ ) -> None:
+     """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
+     batch_size = len(last_page_len)
+     if logits_soft_cap is None:
+         logits_soft_cap = 0.0
+     if self.is_cuda_graph_enabled:
+         if batch_size != self._fixed_batch_size:
+             raise ValueError(
+                 "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
+                 " mismatches the batch size set during initialization {}".format(
+                     batch_size, self._fixed_batch_size
+                 )
+             )
+         if len(indices) > len(self._paged_kv_indices_buf):
+             raise ValueError(
+                 "The size of indices should be less than or equal to the allocated buffer"
+             )
+     else:
+         self._paged_kv_indptr_buf = indptr
+         self._paged_kv_indices_buf = indices
+         self._paged_kv_last_page_len_buf = last_page_len
+     # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
+     if not q_data_type:
+         q_data_type = data_type
+     if not hasattr(self, "empty_q_data"):
+         self.empty_q_data = torch.empty(
+             0,
+             dtype=(
+                 getattr(torch, q_data_type)
+                 if isinstance(q_data_type, str)
+                 else q_data_type
+             ),
+         )
+         self.empty_kv_cache = torch.empty(
+             0,
+             dtype=(
+                 getattr(torch, data_type) if isinstance(data_type, str) else data_type
+             ),
+         )
+         self.last_page_len = torch.ones(32768, dtype=torch.int32)
+     empty_q_data = self.empty_q_data
+     empty_kv_cache = self.empty_kv_cache
+     stream = torch.cuda.current_stream()
+     self._cached_module.plan(
+         self._float_workspace_buffer,
+         self._int_workspace_buffer,
+         self._pin_memory_int_workspace_buffer,
+         indptr.to("cpu"),
+         batch_size,
+         num_qo_heads,
+         num_kv_heads,
+         page_size,
+         self.is_cuda_graph_enabled,
+         window_left,
+         logits_soft_cap,
+         head_dim,
+         empty_q_data,
+         empty_kv_cache,
+         stream.cuda_stream,
+     )
+     self._pos_encoding_mode = pos_encoding_mode
+     self._window_left = window_left
+     self._logits_soft_cap = logits_soft_cap
+     self._sm_scale = sm_scale
+     self._rope_scale = rope_scale
+     self._rope_theta = rope_theta
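Note: fast_decode_plan is not installed globally; the draft backend rebinds it per wrapper instance via decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper), so only that instance takes the faster planning path. A minimal sketch of this per-instance rebinding pattern, with hypothetical names unrelated to flashinfer:

    from functools import partial


    class Wrapper:
        def begin_forward(self, n: int) -> str:
            return f"slow plan for {n}"


    def fast_plan(self: "Wrapper", n: int) -> str:
        # Free function written like a method: `self` is supplied by partial().
        return f"fast plan for {n}"


    w = Wrapper()
    # Rebinding on the instance (not the class) leaves other wrappers untouched.
    w.begin_forward = partial(fast_plan, w)
    print(w.begin_forward(3))          # -> "fast plan for 3"
    print(Wrapper().begin_forward(3))  # -> "slow plan for 3"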
sglang/srt/layers/attention/triton_backend.py

@@ -5,6 +5,9 @@ from typing import TYPE_CHECKING, Optional
  import torch

  from sglang.srt.layers.attention import AttentionBackend
+ from sglang.srt.layers.attention.flashinfer_backend import (
+     create_flashinfer_kv_indices_triton,
+ )
  from sglang.srt.layers.dp_attention import get_attention_tp_size
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

@@ -29,6 +32,15 @@ class TritonAttnBackend(AttentionBackend):
          self.decode_attention_fwd = decode_attention_fwd
          self.extend_attention_fwd = extend_attention_fwd

+         max_bs = model_runner.req_to_token_pool.size
+         self.kv_indptr = torch.zeros(
+             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+         )
+         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+         self.qo_indptr = torch.zeros(
+             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+         )
+
          self.num_head = (
              model_runner.model_config.num_attention_heads // get_attention_tp_size()
          )
@@ -45,6 +57,9 @@ class TritonAttnBackend(AttentionBackend):
      def init_forward_metadata(self, forward_batch: ForwardBatch):
          """Init auxiliary variables for triton attention backend."""

+         bs = forward_batch.batch_size
+         kv_indptr = self.kv_indptr
+
          if forward_batch.forward_mode.is_decode():
              attn_logits = torch.empty(
                  (
@@ -58,11 +73,63 @@ class TritonAttnBackend(AttentionBackend):
              )

              max_extend_len = None
+
+             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+             kv_indptr = kv_indptr[: bs + 1]
+             kv_indices = torch.empty(
+                 forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+             )
+             create_flashinfer_kv_indices_triton[(bs,)](
+                 self.req_to_token,
+                 forward_batch.req_pool_indices,
+                 forward_batch.seq_lens,
+                 kv_indptr,
+                 None,
+                 kv_indices,
+                 self.req_to_token.stride(0),
+             )
+
+             qo_indptr = None
+             custom_mask = None
+             mask_offsets = None
          else:
+             kv_indptr[1 : bs + 1] = torch.cumsum(
+                 forward_batch.extend_prefix_lens, dim=0
+             )
+             kv_indptr = kv_indptr[: bs + 1]
+             kv_indices = torch.empty(
+                 forward_batch.extend_prefix_lens.sum().item(),
+                 dtype=torch.int32,
+                 device=self.device,
+             )
+             create_flashinfer_kv_indices_triton[(bs,)](
+                 self.req_to_token,
+                 forward_batch.req_pool_indices,
+                 forward_batch.extend_prefix_lens,
+                 kv_indptr,
+                 None,
+                 kv_indices,
+                 self.req_to_token.stride(0),
+             )
+
+             qo_indptr = self.qo_indptr
+             qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
+             qo_indptr = qo_indptr[: bs + 1]
+             custom_mask = None
+             mask_offsets = None
+
              attn_logits = None
              max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

-         self.forward_metadata = attn_logits, max_extend_len
+         self.forward_metadata = (
+             attn_logits,
+             max_extend_len,
+             kv_indptr,
+             kv_indices,
+             qo_indptr,
+             custom_mask,
+             mask_offsets,
+         )

      def init_cuda_graph_state(self, max_bs: int):
          self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
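Note: the hunk above moves the Triton backend from passing req_to_token_pool tables directly to passing a CSR-style pair: kv_indptr is the prefix sum of per-request KV lengths and kv_indices concatenates every request's KV slot indices, so request i owns kv_indices[kv_indptr[i]:kv_indptr[i+1]]. A small CPU-only illustration of that layout with made-up lengths (in the backend the indices come from the req_to_token pool via create_flashinfer_kv_indices_triton):

    import torch

    seq_lens = torch.tensor([3, 1, 4], dtype=torch.int32)  # hypothetical per-request KV lengths
    bs = seq_lens.numel()

    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)  # -> [0, 3, 4, 8]

    # Consecutive slot ids stand in for the real pool lookup.
    kv_indices = torch.arange(int(kv_indptr[-1]), dtype=torch.int32)

    for i in range(bs):
        start, end = int(kv_indptr[i]), int(kv_indptr[i + 1])
        print(f"request {i} uses KV slots {kv_indices[start:end].tolist()}")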
@@ -73,7 +140,12 @@ class TritonAttnBackend(AttentionBackend):
          self.cuda_graph_attn_logits = torch.empty(
              (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
              dtype=torch.float32,
-             device="cuda",
+             device=self.device,
+         )
+         self.cuda_graph_kv_indices = torch.zeros(
+             (max_bs * self.cuda_graph_max_seq_len),
+             dtype=torch.int32,
+             device=self.device,
          )

      def init_forward_metadata_capture_cuda_graph(
@@ -90,9 +162,28 @@ class TritonAttnBackend(AttentionBackend):
          assert forward_mode.is_decode(), "Not supported"
          assert spec_info is None, "Not supported"

+         kv_indptr = self.kv_indptr
+         kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+         kv_indptr = kv_indptr[: bs + 1]
+         kv_indices = self.cuda_graph_kv_indices
+         create_flashinfer_kv_indices_triton[(bs,)](
+             self.req_to_token,
+             req_pool_indices,
+             seq_lens,
+             kv_indptr,
+             None,
+             kv_indices,
+             self.req_to_token.stride(0),
+         )
+
          self.forward_metadata = (
              self.cuda_graph_attn_logits,
              None,
+             kv_indptr,
+             kv_indices,
+             None,
+             None,
+             None,
          )

      def init_forward_metadata_replay_cuda_graph(
@@ -109,6 +200,20 @@ class TritonAttnBackend(AttentionBackend):
          self.cuda_graph_start_loc.zero_()
          self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

+         kv_indptr = self.kv_indptr
+         kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+         kv_indptr = kv_indptr[: bs + 1]
+         kv_indices = self.cuda_graph_kv_indices
+         create_flashinfer_kv_indices_triton[(bs,)](
+             self.req_to_token,
+             req_pool_indices[:bs],
+             seq_lens[:bs],
+             kv_indptr,
+             None,
+             kv_indices,
+             self.req_to_token.stride(0),
+         )
+
      def get_cuda_graph_seq_len_fill_value(self):
          return 1

@@ -132,7 +237,15 @@ class TritonAttnBackend(AttentionBackend):
                  layer, forward_batch.out_cache_loc, k, v
              )

-         _, max_extend_len = self.forward_metadata
+         (
+             _,
+             max_extend_len,
+             kv_indptr,
+             kv_indices,
+             qo_indptr,
+             custom_mask,
+             mask_offsets,
+         ) = self.forward_metadata
          self.extend_attention_fwd(
              q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
              k.contiguous(),
@@ -140,11 +253,11 @@ class TritonAttnBackend(AttentionBackend):
              o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
              forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
              forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
-             forward_batch.req_to_token_pool.req_to_token,
-             forward_batch.req_pool_indices,
-             forward_batch.seq_lens,
-             forward_batch.extend_seq_lens,
-             forward_batch.extend_start_loc,
+             qo_indptr,
+             kv_indptr,
+             kv_indices,
+             custom_mask,
+             mask_offsets,
              max_extend_len,
              layer.scaling,
              layer.logit_cap,
@@ -170,7 +283,7 @@ class TritonAttnBackend(AttentionBackend):
          else:
              o = torch.empty_like(q)

-         attn_logits, _ = self.forward_metadata
+         attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata

          if save_kv_cache:
              forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -182,9 +295,8 @@ class TritonAttnBackend(AttentionBackend):
              forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
              forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
              o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-             forward_batch.req_to_token_pool.req_to_token,
-             forward_batch.req_pool_indices,
-             forward_batch.seq_lens,
+             kv_indptr,
+             kv_indices,
              attn_logits,
              self.num_kv_splits,
              layer.scaling,