sglang 0.4.2__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/layers/activation.py +10 -5
  5. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  6. sglang/srt/layers/attention/triton_backend.py +71 -7
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  8. sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
  9. sglang/srt/layers/attention/vision.py +243 -40
  10. sglang/srt/layers/layernorm.py +1 -5
  11. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  22. sglang/srt/layers/moe/topk.py +4 -0
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/fp8.py +7 -0
  46. sglang/srt/layers/quantization/fp8_kernel.py +140 -2
  47. sglang/srt/layers/rotary_embedding.py +29 -15
  48. sglang/srt/layers/sampler.py +9 -6
  49. sglang/srt/lora/backend/__init__.py +8 -0
  50. sglang/srt/lora/backend/base_backend.py +95 -0
  51. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  52. sglang/srt/lora/backend/triton_backend.py +61 -0
  53. sglang/srt/lora/lora.py +127 -112
  54. sglang/srt/lora/lora_manager.py +50 -18
  55. sglang/srt/lora/triton_ops/__init__.py +5 -0
  56. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  57. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  58. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  59. sglang/srt/managers/image_processor.py +77 -38
  60. sglang/srt/managers/scheduler.py +17 -3
  61. sglang/srt/mem_cache/base_prefix_cache.py +4 -0
  62. sglang/srt/mem_cache/chunk_cache.py +3 -0
  63. sglang/srt/mem_cache/radix_cache.py +30 -1
  64. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  65. sglang/srt/model_executor/forward_batch_info.py +58 -59
  66. sglang/srt/model_executor/model_runner.py +2 -2
  67. sglang/srt/models/minicpmv.py +129 -76
  68. sglang/srt/models/mllama.py +16 -56
  69. sglang/srt/models/qwen2.py +4 -1
  70. sglang/srt/models/qwen2_vl.py +19 -9
  71. sglang/srt/server_args.py +19 -2
  72. sglang/srt/speculative/build_eagle_tree.py +4 -2
  73. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  74. sglang/srt/speculative/eagle_utils.py +361 -372
  75. sglang/srt/speculative/eagle_worker.py +177 -45
  76. sglang/srt/utils.py +7 -2
  77. sglang/test/runners.py +2 -0
  78. sglang/utils.py +42 -0
  79. sglang/version.py +1 -1
  80. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +16 -7
  81. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +84 -45
  82. sglang/srt/layers/custom_op_util.py +0 -25
  83. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
  84. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
  85. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0
sglang/srt/constrained/outlines_backend.py
@@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Tuple, Union
 import interegular
 import torch
 from outlines.fsm.guide import RegexGuide
-from outlines.fsm.json_schema import build_regex_from_schema
 from outlines.models.transformers import TransformerTokenizer
 from pydantic import BaseModel
 
@@ -29,6 +28,15 @@ from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarObject,
 )
 from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()
+
+if is_hip_:
+    from outlines_core.fsm.json_schema import build_regex_from_schema
+else:
+    from outlines.fsm.json_schema import build_regex_from_schema
+
 
 logger = logging.getLogger(__name__)
 

sglang/srt/custom_op.py
@@ -0,0 +1,40 @@
+import torch
+from torch import nn
+
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_rocm = torch.cuda.is_available() and torch.version.hip
+
+
+class CustomOp(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._forward_method = self.dispatch_forward()
+
+    def forward(self, *args, **kwargs):
+        return self._forward_method(*args, **kwargs)
+
+    def forward_native(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_cuda(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_hip(self, *args, **kwargs):
+        return self.forward_cuda(*args, **kwargs)
+
+    def forward_xpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def forward_hpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def forward_cpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def dispatch_forward(self):
+        if _is_cuda:
+            return self.forward_cuda
+        elif _is_rocm:
+            return self.forward_hip
+        else:
+            return self.forward_native
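
Note: the new sglang/srt/custom_op.py gives sglang its own CustomOp base class (activation.py below switches to it from vllm.model_executor.custom_op). The platform-specific forward method is bound once, at construction time, instead of being selected on every call. A minimal sketch of how a subclass plugs into this dispatch; the ScaledSilu op and its scale parameter are hypothetical, for illustration only:

import torch

from sglang.srt.custom_op import CustomOp


class ScaledSilu(CustomOp):
    # Hypothetical op: silu(x) * scale, shown only to illustrate the dispatch.
    def __init__(self, scale: float = 2.0):
        super().__init__()  # dispatch_forward() picks cuda/hip/native here
        self.scale = scale

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.silu(x) * self.scale

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would call into sgl-kernel here; fall back to the native math.
        return self.forward_native(x)


op = ScaledSilu()
y = op(torch.randn(4, 8))  # routed through the method chosen in __init__

Because forward_hip defaults to forward_cuda and forward_cpu/xpu/hpu default to forward_native, a subclass usually only needs forward_native plus an optional forward_cuda.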

sglang/srt/entrypoints/engine.py
@@ -316,8 +316,8 @@ def _set_envs_and_config(server_args: ServerArgs):
     # Check flashinfer version
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
-            "flashinfer",
-            "0.1.6",
+            "flashinfer_python",
+            "0.2.0.post2",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",

sglang/srt/layers/activation.py
@@ -25,21 +25,18 @@ from sglang.srt.utils import is_cuda_available
 if is_cuda_available():
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
-from vllm.model_executor.custom_op import CustomOp
-
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
 
 
-@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -53,7 +50,6 @@ class SiluAndMul(CustomOp):
         return out
 
 
-@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
@@ -76,6 +72,15 @@ class GeluAndMul(CustomOp):
         return out
 
 
+class QuickGELU(CustomOp):
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.sigmoid(1.702 * x)
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel
+        return self.forward_native(x)
+
+
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
 
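Note: QuickGELU is the sigmoid-based GELU approximation popularized by CLIP, x * sigmoid(1.702 * x); its forward_cuda path currently just reuses the native formula until an sgl-kernel implementation lands (per the TODO above). A quick illustrative check against PyTorch's exact GELU, not part of the package:

import torch

x = torch.linspace(-6, 6, steps=1001)
quick_gelu = x * torch.sigmoid(1.702 * x)   # the QuickGELU formula added above
exact_gelu = torch.nn.functional.gelu(x)    # erf-based GELU for reference

# The approximation stays within a few hundredths of the exact value on this range.
print(f"max abs diff: {(quick_gelu - exact_gelu).abs().max().item():.4f}")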
 
sglang/srt/layers/attention/flashinfer_backend.py
@@ -10,6 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
+from functools import partial
 from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
@@ -34,6 +35,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
+    from flashinfer.decode import PosEncodingMode
 
 
 class WrapperDispatch(Enum):
@@ -53,10 +55,19 @@ class PrefillMetadata:
     extend_no_prefix: bool
 
 
+# Reuse this workspace buffer across all flashinfer wrappers
+global_workspace_buffer = None
+
+
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
 
-    def __init__(self, model_runner: ModelRunner):
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+    ):
         super().__init__()
 
         # Parse constants
@@ -69,6 +80,7 @@ class FlashInferAttnBackend(AttentionBackend):
             ),
         )
         self.max_context_len = model_runner.model_config.context_len
+        self.skip_prefill = skip_prefill
 
         assert not (
             model_runner.sliding_window_size is not None
@@ -90,16 +102,26 @@ class FlashInferAttnBackend(AttentionBackend):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
         # Allocate buffers
-        self.workspace_buffer = torch.empty(
-            global_config.flashinfer_workspace_size,
-            dtype=torch.uint8,
-            device=model_runner.device,
-        )
+        global global_workspace_buffer
+        if global_workspace_buffer is None:
+            global_workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        self.workspace_buffer = global_workspace_buffer
         max_bs = model_runner.req_to_token_pool.size
-        self.kv_indptr = [
-            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
-            for _ in range(self.num_wrappers)
-        ]
+        if kv_indptr_buf is None:
+            self.kv_indptr = [
+                torch.zeros(
+                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                )
+                for _ in range(self.num_wrappers)
+            ]
+        else:
+            assert self.num_wrappers == 1
+            self.kv_indptr = [kv_indptr_buf]
+
         self.kv_last_page_len = torch.ones(
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
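
Note: every FlashInferAttnBackend now hands out the same module-level workspace buffer instead of allocating its own, which keeps memory flat when the multi-step draft backend added below constructs one backend per speculative step. A bare-bones sketch of the same lazy, module-level sharing pattern; get_shared_workspace is an illustrative name, not sglang's API:

from typing import Optional

import torch

_shared_workspace: Optional[torch.Tensor] = None


def get_shared_workspace(num_bytes: int, device: str = "cuda") -> torch.Tensor:
    # Allocate once on first use, then hand the same storage to every caller.
    global _shared_workspace
    if _shared_workspace is None or _shared_workspace.numel() < num_bytes:
        _shared_workspace = torch.empty(num_bytes, dtype=torch.uint8, device=device)
    return _shared_workspace[:num_bytes]


buf_a = get_shared_workspace(512 * 1024 * 1024)
buf_b = get_shared_workspace(512 * 1024 * 1024)
assert buf_a.data_ptr() == buf_b.data_ptr()  # both views share one allocation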
@@ -122,12 +144,17 @@ class FlashInferAttnBackend(AttentionBackend):
         self.prefill_wrappers_verify = []
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
-            self.prefill_wrappers_paged.append(
-                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-            )
-            self.prefill_wrappers_verify.append(
-                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-            )
+            if not skip_prefill:
+                self.prefill_wrappers_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        backend="fa2",
+                    )
+                )
+                self.prefill_wrappers_verify.append(
+                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+                )
             self.decode_wrappers.append(
                 BatchDecodeWithPagedKVCacheWrapper(
                     self.workspace_buffer,
@@ -137,10 +164,11 @@ class FlashInferAttnBackend(AttentionBackend):
             )
 
         # Create indices updater
+        if not skip_prefill:
+            self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
+                model_runner, self
+            )
         self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
-        self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
-            model_runner, self
-        )
 
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
@@ -211,23 +239,30 @@ class FlashInferAttnBackend(AttentionBackend):
             self.prefill_wrappers_paged, use_ragged, extend_no_prefix
         )
 
-    def init_cuda_graph_state(self, max_bs: int):
-        cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.max_context_len,),
-            dtype=torch.int32,
-            device="cuda",
-        )
+    def init_cuda_graph_state(
+        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+    ):
+        if kv_indices_buf is None:
+            cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len,),
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = kv_indices_buf
+
         self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
             cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]
 
-        self.cuda_graph_custom_mask = torch.zeros(
-            (max_bs * self.max_context_len),
-            dtype=torch.uint8,
-            device="cuda",
-        )
-        self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
-        self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
+        if not self.skip_prefill:
+            self.cuda_graph_custom_mask = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
+            self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
 
     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -279,7 +314,7 @@ class FlashInferAttnBackend(AttentionBackend):
                     paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
                     paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
                     custom_mask_buf=self.cuda_graph_custom_mask,
-                    qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
+                    mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
                 )
             )
             seq_lens_sum = seq_lens.sum().item()
@@ -602,11 +637,8 @@ class FlashInferIndicesUpdaterDecode:
                 self.req_to_token.shape[1],
             )
         else:
-            bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
-                req_pool_indices,
-                paged_kernel_lens,
-                self.req_to_token,
-            )
+            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+            bs = kv_indptr.shape[0] - 1
 
         wrapper.end_forward()
         wrapper.begin_forward(
@@ -800,7 +832,9 @@ class FlashInferIndicesUpdaterPrefill:
             kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+                paged_kernel_lens_sum + 256,
+                dtype=torch.int32,
+                device=req_pool_indices.device,
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
@@ -852,6 +886,132 @@ class FlashInferIndicesUpdaterPrefill:
         )
 
 
+class FlashInferMultiStepDraftBackend:
+    """
+    Wrap multiple flashinfer attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                FlashInferAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.kv_indptr_stride = self.kv_indptr.shape[1]
+
+    def common_template(self, forward_batch: ForwardBatch, call_fn: int):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            self.cuda_graph_kv_indices,
+            self.kv_indptr,
+            forward_batch.positions,
+            num_seqs,
+            self.topk,
+            self.pool_len,
+            self.kv_indptr_stride,
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+        )
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_bs * self.max_context_len),
+            dtype=torch.int32,
+            device="cuda",
+        )
+        self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+            decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
+                forward_batch.batch_size
+            ][0]
+            decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, call_fn)
+
+
 @triton.jit
 def create_flashinfer_kv_indices_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
@@ -935,3 +1095,88 @@ def should_use_tensor_core(
         return gqa_group_size > 4
     else:
         return False
+
+
+def fast_decode_plan(
+    self,
+    indptr: torch.Tensor,
+    indices: torch.Tensor,
+    last_page_len: torch.Tensor,
+    num_qo_heads: int,
+    num_kv_heads: int,
+    head_dim: int,
+    page_size: int,
+    pos_encoding_mode: str = "NONE",
+    window_left: int = -1,
+    logits_soft_cap: Optional[float] = None,
+    data_type: Union[str, torch.dtype] = "float16",
+    q_data_type: Optional[Union[str, torch.dtype]] = None,
+    sm_scale: Optional[float] = None,
+    rope_scale: Optional[float] = None,
+    rope_theta: Optional[float] = None,
+) -> None:
+    """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
+    batch_size = len(last_page_len)
+    if logits_soft_cap is None:
+        logits_soft_cap = 0.0
+    if self.is_cuda_graph_enabled:
+        if batch_size != self._fixed_batch_size:
+            raise ValueError(
+                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
+                " mismatches the batch size set during initialization {}".format(
+                    batch_size, self._fixed_batch_size
+                )
+            )
+        if len(indices) > len(self._paged_kv_indices_buf):
+            raise ValueError(
+                "The size of indices should be less than or equal to the allocated buffer"
+            )
+    else:
+        self._paged_kv_indptr_buf = indptr
+        self._paged_kv_indices_buf = indices
+        self._paged_kv_last_page_len_buf = last_page_len
+    # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
+    if not q_data_type:
+        q_data_type = data_type
+    if not hasattr(self, "empty_q_data"):
+        self.empty_q_data = torch.empty(
+            0,
+            dtype=(
+                getattr(torch, q_data_type)
+                if isinstance(q_data_type, str)
+                else q_data_type
+            ),
+        )
+        self.empty_kv_cache = torch.empty(
+            0,
+            dtype=(
+                getattr(torch, data_type) if isinstance(data_type, str) else data_type
+            ),
+        )
+        self.last_page_len = torch.ones(32768, dtype=torch.int32)
+    empty_q_data = self.empty_q_data
+    empty_kv_cache = self.empty_kv_cache
+    stream = torch.cuda.current_stream()
+    self._cached_module.plan(
+        self._float_workspace_buffer,
+        self._int_workspace_buffer,
+        self._pin_memory_int_workspace_buffer,
+        indptr.to("cpu"),
+        batch_size,
+        num_qo_heads,
+        num_kv_heads,
+        page_size,
+        self.is_cuda_graph_enabled,
+        window_left,
+        logits_soft_cap,
+        head_dim,
+        empty_q_data,
+        empty_kv_cache,
+        stream.cuda_stream,
+    )
+    self._pos_encoding_mode = pos_encoding_mode
+    self._window_left = window_left
+    self._logits_soft_cap = logits_soft_cap
+    self._sm_scale = sm_scale
+    self._rope_scale = rope_scale
+    self._rope_theta = rope_theta

sglang/srt/layers/attention/triton_backend.py
@@ -5,6 +5,9 @@ from typing import TYPE_CHECKING, Optional
 import torch
 
 from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_backend import (
+    create_flashinfer_kv_indices_triton,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
@@ -29,6 +32,12 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
 
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
@@ -58,11 +67,32 @@ class TritonAttnBackend(AttentionBackend):
             )
 
             max_extend_len = None
+
+            kv_indptr = self.kv_indptr
+            bs = len(forward_batch.req_pool_indices)
+            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                forward_batch.req_to_token_pool.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                forward_batch.req_to_token_pool.req_to_token.stride(0),
+            )
+
         else:
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
-        self.forward_metadata = attn_logits, max_extend_len
+            kv_indptr = None
+            kv_indices = None
+
+        self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
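
Note: the Triton decode path now receives the same CSR-style pair the flashinfer backend uses: kv_indptr holds cumulative sequence lengths and kv_indices holds the flattened KV-cache slot numbers gathered from req_to_token by create_flashinfer_kv_indices_triton. A pure-PyTorch sketch of that layout, handy for checking the kernel's output on small inputs; build_kv_indices_reference is an illustrative helper, not part of sglang:

import torch


def build_kv_indices_reference(req_to_token, req_pool_indices, seq_lens):
    # CPU reference for the (kv_indptr, kv_indices) pair the Triton kernel fills.
    bs = req_pool_indices.numel()
    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i in range(bs):
        start, end = int(kv_indptr[i]), int(kv_indptr[i + 1])
        # Row req_pool_indices[i] of req_to_token lists request i's KV-cache slots.
        kv_indices[start:end] = req_to_token[req_pool_indices[i], : int(seq_lens[i])]
    return kv_indptr, kv_indices


# Tiny example: 3 requests of lengths 2, 1, 3 in a 4-row token pool.
req_to_token = torch.arange(4 * 8).reshape(4, 8)
req_pool_indices = torch.tensor([2, 0, 3])
seq_lens = torch.tensor([2, 1, 3])
kv_indptr, kv_indices = build_kv_indices_reference(req_to_token, req_pool_indices, seq_lens)
print(kv_indptr.tolist())   # [0, 2, 3, 6]
print(kv_indices.tolist())  # [16, 17, 0, 24, 25, 26]

Request i's slots are kv_indices[kv_indptr[i] : kv_indptr[i + 1]], which is the layout the updated decode call in the final hunk below consumes.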
@@ -73,7 +103,12 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_attn_logits = torch.empty(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
-            device="cuda",
+            device=self.device,
+        )
+        self.cuda_graph_kv_indices = torch.zeros(
+            (max_bs * self.cuda_graph_max_seq_len),
+            dtype=torch.int32,
+            device=self.device,
         )
 
     def init_forward_metadata_capture_cuda_graph(
@@ -90,9 +125,25 @@ class TritonAttnBackend(AttentionBackend):
         assert forward_mode.is_decode(), "Not supported"
         assert spec_info is None, "Not supported"
 
+        kv_indptr = self.kv_indptr
+        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indices = self.cuda_graph_kv_indices
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            seq_lens,
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token.stride(0),
+        )
+
         self.forward_metadata = (
             self.cuda_graph_attn_logits,
             None,
+            kv_indptr,
+            kv_indices,
         )
 
     def init_forward_metadata_replay_cuda_graph(
@@ -109,6 +160,20 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
 
+        kv_indptr = self.kv_indptr
+        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indices = self.cuda_graph_kv_indices
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices[:bs],
+            seq_lens[:bs],
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token.stride(0),
+        )
+
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
 
@@ -132,7 +197,7 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        _, max_extend_len = self.forward_metadata
+        _, max_extend_len, _, _ = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
@@ -170,7 +235,7 @@ class TritonAttnBackend(AttentionBackend):
         else:
            o = torch.empty_like(q)

-        attn_logits, _ = self.forward_metadata
+        attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
 
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -182,9 +247,8 @@ class TritonAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            forward_batch.req_to_token_pool.req_to_token,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             self.num_kv_splits,
             layer.scaling,