sglang 0.4.2__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/constrained/outlines_backend.py +9 -1
- sglang/srt/custom_op.py +40 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/layers/activation.py +10 -5
- sglang/srt/layers/attention/flashinfer_backend.py +284 -39
- sglang/srt/layers/attention/triton_backend.py +71 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
- sglang/srt/layers/attention/vision.py +243 -40
- sglang/srt/layers/layernorm.py +1 -5
- sglang/srt/layers/moe/ep_moe/layer.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
- sglang/srt/layers/moe/topk.py +4 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +7 -0
- sglang/srt/layers/quantization/fp8_kernel.py +140 -2
- sglang/srt/layers/rotary_embedding.py +29 -15
- sglang/srt/layers/sampler.py +9 -6
- sglang/srt/lora/backend/__init__.py +8 -0
- sglang/srt/lora/backend/base_backend.py +95 -0
- sglang/srt/lora/backend/flashinfer_backend.py +91 -0
- sglang/srt/lora/backend/triton_backend.py +61 -0
- sglang/srt/lora/lora.py +127 -112
- sglang/srt/lora/lora_manager.py +50 -18
- sglang/srt/lora/triton_ops/__init__.py +5 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
- sglang/srt/managers/image_processor.py +77 -38
- sglang/srt/managers/scheduler.py +17 -3
- sglang/srt/mem_cache/base_prefix_cache.py +4 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +30 -1
- sglang/srt/model_executor/cuda_graph_runner.py +77 -80
- sglang/srt/model_executor/forward_batch_info.py +58 -59
- sglang/srt/model_executor/model_runner.py +2 -2
- sglang/srt/models/minicpmv.py +129 -76
- sglang/srt/models/mllama.py +16 -56
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_vl.py +19 -9
- sglang/srt/server_args.py +19 -2
- sglang/srt/speculative/build_eagle_tree.py +4 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
- sglang/srt/speculative/eagle_utils.py +361 -372
- sglang/srt/speculative/eagle_worker.py +177 -45
- sglang/srt/utils.py +7 -2
- sglang/test/runners.py +2 -0
- sglang/utils.py +42 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +16 -7
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +84 -45
- sglang/srt/layers/custom_op_util.py +0 -25
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0
sglang/srt/constrained/outlines_backend.py
CHANGED
@@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Tuple, Union
 import interegular
 import torch
 from outlines.fsm.guide import RegexGuide
-from outlines.fsm.json_schema import build_regex_from_schema
 from outlines.models.transformers import TransformerTokenizer
 from pydantic import BaseModel
 
@@ -29,6 +28,15 @@ from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarObject,
 )
 from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()
+
+if is_hip_:
+    from outlines_core.fsm.json_schema import build_regex_from_schema
+else:
+    from outlines.fsm.json_schema import build_regex_from_schema
+
 
 logger = logging.getLogger(__name__)
 
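For context on the conditionally imported helper: build_regex_from_schema turns a JSON schema (passed as a JSON string) into a regular expression that constrains decoding, and the ROCm branch simply sources the same function from outlines_core. A minimal sketch of the non-HIP path, assuming a recent outlines release with this function name and signature; the schema below is made up for illustration.

import json

from outlines.fsm.json_schema import build_regex_from_schema

# Hypothetical schema used only for illustration.
schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

# The helper expects the schema serialized as a JSON string and returns a regex
# whose matches are exactly the JSON documents valid under that schema.
regex_str = build_regex_from_schema(json.dumps(schema))
print(regex_str[:80])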
sglang/srt/custom_op.py
ADDED
@@ -0,0 +1,40 @@
+import torch
+from torch import nn
+
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_rocm = torch.cuda.is_available() and torch.version.hip
+
+
+class CustomOp(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._forward_method = self.dispatch_forward()
+
+    def forward(self, *args, **kwargs):
+        return self._forward_method(*args, **kwargs)
+
+    def forward_native(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_cuda(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_hip(self, *args, **kwargs):
+        return self.forward_cuda(*args, **kwargs)
+
+    def forward_xpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def forward_hpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def forward_cpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def dispatch_forward(self):
+        if _is_cuda:
+            return self.forward_cuda
+        elif _is_rocm:
+            return self.forward_hip
+        else:
+            return self.forward_native
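The new CustomOp base class resolves a forward implementation once, at construction time, based on the detected platform (CUDA, ROCm, or neither) rather than re-checking on every call. A minimal sketch of how a subclass plugs into this dispatch; ReLUSquared is a made-up op used only for illustration, not part of sglang.

import torch

from sglang.srt.custom_op import CustomOp


class ReLUSquared(CustomOp):
    """Illustrative op: relu(x) ** 2."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Portable PyTorch path, used on CPU/XPU/HPU and whenever no GPU build is present.
        return torch.relu(x).square()

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would call a fused sgl_kernel here; the sketch just reuses the native path.
        return self.forward_native(x)


op = ReLUSquared()           # dispatch_forward() picks the backend once, in __init__
y = op(torch.randn(4, 8))    # forward() only calls the cached _forward_method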
sglang/srt/entrypoints/engine.py
CHANGED
@@ -316,8 +316,8 @@ def _set_envs_and_config(server_args: ServerArgs):
     # Check flashinfer version
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
-            "
-            "0.
+            "flashinfer_python",
+            "0.2.0.post2",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
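assert_pkg_version fails fast when the installed flashinfer_python wheel is missing or older than the pinned version. A rough standard-library sketch of that kind of gate; the helper name and messages below are illustrative and not sglang's implementation.

from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse


def check_min_pkg_version(pkg: str, min_version: str, hint: str) -> None:
    # Illustrative "package present and new enough" check.
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {hint}")
    if parse(installed) < parse(min_version):
        raise RuntimeError(f"{pkg}=={installed} is older than {min_version}. {hint}")


check_min_pkg_version(
    "flashinfer_python",
    "0.2.0.post2",
    "See https://docs.flashinfer.ai/installation.html.",
)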
sglang/srt/layers/activation.py
CHANGED
@@ -25,21 +25,18 @@ from sglang.srt.utils import is_cuda_available
 if is_cuda_available():
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
-from
-
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
 
 
-@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -53,7 +50,6 @@ class SiluAndMul(CustomOp):
         return out
 
 
-@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
@@ -76,6 +72,15 @@ class GeluAndMul(CustomOp):
         return out
 
 
+class QuickGELU(CustomOp):
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.sigmoid(1.702 * x)
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel
+        return self.forward_native(x)
+
+
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
 
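QuickGELU is the sigmoid approximation x * sigmoid(1.702 * x); until a fused kernel lands in sgl-kernel, the CUDA path reuses the native formula. A quick, self-contained check of how closely it tracks exact GELU, in plain PyTorch and independent of sglang:

import torch


def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    # Same formula as QuickGELU.forward_native above.
    return x * torch.sigmoid(1.702 * x)


x = torch.linspace(-4.0, 4.0, steps=1001)
exact = torch.nn.functional.gelu(x)            # erf-based GELU
max_err = (quick_gelu(x) - exact).abs().max()
print(max_err)                                 # on the order of 1e-2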
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -10,6 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) and
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
+from functools import partial
 from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
@@ -34,6 +35,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
+    from flashinfer.decode import PosEncodingMode
 
 
 class WrapperDispatch(Enum):
@@ -53,10 +55,19 @@ class PrefillMetadata:
     extend_no_prefix: bool
 
 
+# Reuse this workspace buffer across all flashinfer wrappers
+global_workspace_buffer = None
+
+
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
 
-    def __init__(
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+    ):
         super().__init__()
 
         # Parse constants
@@ -69,6 +80,7 @@ class FlashInferAttnBackend(AttentionBackend):
             ),
         )
         self.max_context_len = model_runner.model_config.context_len
+        self.skip_prefill = skip_prefill
 
         assert not (
             model_runner.sliding_window_size is not None
@@ -90,16 +102,26 @@ class FlashInferAttnBackend(AttentionBackend):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
         # Allocate buffers
-
-
-
-
-
+        global global_workspace_buffer
+        if global_workspace_buffer is None:
+            global_workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        self.workspace_buffer = global_workspace_buffer
         max_bs = model_runner.req_to_token_pool.size
-
-
-
-
+        if kv_indptr_buf is None:
+            self.kv_indptr = [
+                torch.zeros(
+                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                )
+                for _ in range(self.num_wrappers)
+            ]
+        else:
+            assert self.num_wrappers == 1
+            self.kv_indptr = [kv_indptr_buf]
+
         self.kv_last_page_len = torch.ones(
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
@@ -122,12 +144,17 @@ class FlashInferAttnBackend(AttentionBackend):
         self.prefill_wrappers_verify = []
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
-
-
-
-
-
-
+            if not skip_prefill:
+                self.prefill_wrappers_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        backend="fa2",
+                    )
+                )
+                self.prefill_wrappers_verify.append(
+                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+                )
             self.decode_wrappers.append(
                 BatchDecodeWithPagedKVCacheWrapper(
                     self.workspace_buffer,
@@ -137,10 +164,11 @@ class FlashInferAttnBackend(AttentionBackend):
             )
 
         # Create indices updater
+        if not skip_prefill:
+            self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
+                model_runner, self
+            )
         self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
-        self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
-            model_runner, self
-        )
 
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
@@ -211,23 +239,30 @@ class FlashInferAttnBackend(AttentionBackend):
            self.prefill_wrappers_paged, use_ragged, extend_no_prefix
         )
 
-    def init_cuda_graph_state(
-
-
-
-
-
+    def init_cuda_graph_state(
+        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+    ):
+        if kv_indices_buf is None:
+            cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len,),
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = kv_indices_buf
 
         self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
             cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]
 
-
-
-
-
-
-
-
+        if not self.skip_prefill:
+            self.cuda_graph_custom_mask = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
+            self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
 
     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -279,7 +314,7 @@ class FlashInferAttnBackend(AttentionBackend):
                     paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
                     paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
                     custom_mask_buf=self.cuda_graph_custom_mask,
-
+                    mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
                 )
             )
             seq_lens_sum = seq_lens.sum().item()
@@ -602,11 +637,8 @@ class FlashInferIndicesUpdaterDecode:
                 self.req_to_token.shape[1],
             )
         else:
-
-
-                paged_kernel_lens,
-                self.req_to_token,
-            )
+            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+            bs = kv_indptr.shape[0] - 1
 
         wrapper.end_forward()
         wrapper.begin_forward(
@@ -800,7 +832,9 @@ class FlashInferIndicesUpdaterPrefill:
             kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                paged_kernel_lens_sum
+                paged_kernel_lens_sum + 256,
+                dtype=torch.int32,
+                device=req_pool_indices.device,
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
@@ -852,6 +886,132 @@ class FlashInferIndicesUpdaterPrefill:
         )
 
 
+class FlashInferMultiStepDraftBackend:
+    """
+    Wrap multiple flashinfer attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                FlashInferAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.kv_indptr_stride = self.kv_indptr.shape[1]
+
+    def common_template(self, forward_batch: ForwardBatch, call_fn: int):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            self.cuda_graph_kv_indices,
+            self.kv_indptr,
+            forward_batch.positions,
+            num_seqs,
+            self.topk,
+            self.pool_len,
+            self.kv_indptr_stride,
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+        )
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_bs * self.max_context_len),
+            dtype=torch.int32,
+            device="cuda",
+        )
+        self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+            decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
+                forward_batch.batch_size
+            ][0]
+            decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, call_fn)
+
+
 @triton.jit
 def create_flashinfer_kv_indices_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
@@ -935,3 +1095,88 @@ def should_use_tensor_core(
         return gqa_group_size > 4
     else:
         return False
+
+
+def fast_decode_plan(
+    self,
+    indptr: torch.Tensor,
+    indices: torch.Tensor,
+    last_page_len: torch.Tensor,
+    num_qo_heads: int,
+    num_kv_heads: int,
+    head_dim: int,
+    page_size: int,
+    pos_encoding_mode: str = "NONE",
+    window_left: int = -1,
+    logits_soft_cap: Optional[float] = None,
+    data_type: Union[str, torch.dtype] = "float16",
+    q_data_type: Optional[Union[str, torch.dtype]] = None,
+    sm_scale: Optional[float] = None,
+    rope_scale: Optional[float] = None,
+    rope_theta: Optional[float] = None,
+) -> None:
+    """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
+    batch_size = len(last_page_len)
+    if logits_soft_cap is None:
+        logits_soft_cap = 0.0
+    if self.is_cuda_graph_enabled:
+        if batch_size != self._fixed_batch_size:
+            raise ValueError(
+                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
+                " mismatches the batch size set during initialization {}".format(
+                    batch_size, self._fixed_batch_size
+                )
+            )
+        if len(indices) > len(self._paged_kv_indices_buf):
+            raise ValueError(
+                "The size of indices should be less than or equal to the allocated buffer"
+            )
+    else:
+        self._paged_kv_indptr_buf = indptr
+        self._paged_kv_indices_buf = indices
+        self._paged_kv_last_page_len_buf = last_page_len
+    # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
+    if not q_data_type:
+        q_data_type = data_type
+    if not hasattr(self, "empty_q_data"):
+        self.empty_q_data = torch.empty(
+            0,
+            dtype=(
+                getattr(torch, q_data_type)
+                if isinstance(q_data_type, str)
+                else q_data_type
+            ),
+        )
+        self.empty_kv_cache = torch.empty(
+            0,
+            dtype=(
+                getattr(torch, data_type) if isinstance(data_type, str) else data_type
+            ),
+        )
+        self.last_page_len = torch.ones(32768, dtype=torch.int32)
+    empty_q_data = self.empty_q_data
+    empty_kv_cache = self.empty_kv_cache
+    stream = torch.cuda.current_stream()
+    self._cached_module.plan(
+        self._float_workspace_buffer,
+        self._int_workspace_buffer,
+        self._pin_memory_int_workspace_buffer,
+        indptr.to("cpu"),
+        batch_size,
+        num_qo_heads,
+        num_kv_heads,
+        page_size,
+        self.is_cuda_graph_enabled,
+        window_left,
+        logits_soft_cap,
+        head_dim,
+        empty_q_data,
+        empty_kv_cache,
+        stream.cuda_stream,
+    )
+    self._pos_encoding_mode = pos_encoding_mode
+    self._window_left = window_left
+    self._logits_soft_cap = logits_soft_cap
+    self._sm_scale = sm_scale
+    self._rope_scale = rope_scale
+    self._rope_theta = rope_theta
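The global_workspace_buffer introduced above turns the flashinfer workspace into a module-level singleton, so the per-step FlashInferAttnBackend instances created by FlashInferMultiStepDraftBackend share one allocation instead of each reserving their own. The pattern in isolation, as a sketch with an illustrative helper name and a small CPU buffer so it runs anywhere:

import torch

_global_workspace_buffer = None  # allocated lazily, then shared by every caller


def get_workspace_buffer(size_bytes: int, device: str = "cpu") -> torch.Tensor:
    """Return the shared uint8 workspace, allocating it on first use."""
    global _global_workspace_buffer
    if _global_workspace_buffer is None:
        _global_workspace_buffer = torch.empty(size_bytes, dtype=torch.uint8, device=device)
    return _global_workspace_buffer


buf_a = get_workspace_buffer(1 << 20)
buf_b = get_workspace_buffer(1 << 20)
assert buf_a.data_ptr() == buf_b.data_ptr()  # same memory, no duplicate allocation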
sglang/srt/layers/attention/triton_backend.py
CHANGED
@@ -5,6 +5,9 @@ from typing import TYPE_CHECKING, Optional
 import torch
 
 from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_backend import (
+    create_flashinfer_kv_indices_triton,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
@@ -29,6 +32,12 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
 
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
@@ -58,11 +67,32 @@ class TritonAttnBackend(AttentionBackend):
             )
 
             max_extend_len = None
+
+            kv_indptr = self.kv_indptr
+            bs = len(forward_batch.req_pool_indices)
+            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                forward_batch.req_to_token_pool.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                forward_batch.req_to_token_pool.req_to_token.stride(0),
+            )
+
         else:
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
-
+            kv_indptr = None
+            kv_indices = None
+
+        self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
@@ -73,7 +103,12 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_attn_logits = torch.empty(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
-            device=
+            device=self.device,
+        )
+        self.cuda_graph_kv_indices = torch.zeros(
+            (max_bs * self.cuda_graph_max_seq_len),
+            dtype=torch.int32,
+            device=self.device,
         )
 
     def init_forward_metadata_capture_cuda_graph(
@@ -90,9 +125,25 @@ class TritonAttnBackend(AttentionBackend):
         assert forward_mode.is_decode(), "Not supported"
         assert spec_info is None, "Not supported"
 
+        kv_indptr = self.kv_indptr
+        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indices = self.cuda_graph_kv_indices
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            seq_lens,
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token.stride(0),
+        )
+
         self.forward_metadata = (
             self.cuda_graph_attn_logits,
             None,
+            kv_indptr,
+            kv_indices,
         )
 
     def init_forward_metadata_replay_cuda_graph(
@@ -109,6 +160,20 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
 
+        kv_indptr = self.kv_indptr
+        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indices = self.cuda_graph_kv_indices
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices[:bs],
+            seq_lens[:bs],
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token.stride(0),
+        )
+
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
 
@@ -132,7 +197,7 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        _, max_extend_len = self.forward_metadata
+        _, max_extend_len, _, _ = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -170,7 +235,7 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        attn_logits, _ = self.forward_metadata
+        attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
 
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -182,9 +247,8 @@ class TritonAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-
-
-            forward_batch.seq_lens,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             self.num_kv_splits,
             layer.scaling,
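Both backends now describe the paged KV cache with the same CSR-style pair: kv_indptr holds the cumulative sequence lengths and kv_indices concatenates each request's token slots taken from req_to_token. A pure-PyTorch reference for what create_flashinfer_kv_indices_triton fills in; the shapes and values below are illustrative only.

import torch


def build_kv_indices_reference(req_to_token, req_pool_indices, seq_lens):
    # req_to_token: [max_batch, max_context_len] token slot table
    # req_pool_indices: [bs] row of req_to_token owned by each request
    # seq_lens: [bs] current KV length of each request
    bs = seq_lens.numel()
    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
    kv_indices = torch.cat(
        [req_to_token[req_pool_indices[i], : seq_lens[i]] for i in range(bs)]
    ).to(torch.int32)
    return kv_indptr, kv_indices


# Toy example: two requests with KV lengths 3 and 2.
req_to_token = torch.arange(32).reshape(4, 8)
kv_indptr, kv_indices = build_kv_indices_reference(
    req_to_token, torch.tensor([2, 0]), torch.tensor([3, 2])
)
print(kv_indptr.tolist())   # [0, 3, 5]
print(kv_indices.tolist())  # [16, 17, 18, 0, 1]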