sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
- sglang/srt/constrained/outlines_backend.py +9 -1
- sglang/srt/custom_op.py +40 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/function_call_parser.py +96 -69
- sglang/srt/layers/activation.py +10 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
- sglang/srt/layers/attention/flashinfer_backend.py +284 -39
- sglang/srt/layers/attention/triton_backend.py +124 -12
- sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
- sglang/srt/layers/layernorm.py +1 -5
- sglang/srt/layers/moe/ep_moe/layer.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
- sglang/srt/layers/moe/topk.py +4 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8_kernel.py +173 -2
- sglang/srt/layers/rotary_embedding.py +1 -3
- sglang/srt/layers/sampler.py +4 -4
- sglang/srt/lora/backend/__init__.py +8 -0
- sglang/srt/lora/backend/base_backend.py +95 -0
- sglang/srt/lora/backend/flashinfer_backend.py +91 -0
- sglang/srt/lora/backend/triton_backend.py +61 -0
- sglang/srt/lora/lora.py +127 -112
- sglang/srt/lora/lora_manager.py +50 -18
- sglang/srt/lora/triton_ops/__init__.py +5 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
- sglang/srt/model_executor/cuda_graph_runner.py +77 -80
- sglang/srt/model_executor/forward_batch_info.py +58 -59
- sglang/srt/model_executor/model_runner.py +2 -2
- sglang/srt/models/llama.py +8 -3
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/server_args.py +13 -2
- sglang/srt/speculative/build_eagle_tree.py +486 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
- sglang/srt/speculative/eagle_utils.py +420 -401
- sglang/srt/speculative/eagle_worker.py +177 -45
- sglang/srt/utils.py +7 -0
- sglang/test/runners.py +2 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/METADATA +15 -6
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/RECORD +77 -38
- sglang/srt/layers/custom_op_util.py +0 -25
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py

@@ -10,6 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
+from functools import partial
 from typing import TYPE_CHECKING, List, Optional, Union

 import torch
@@ -34,6 +35,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
+    from flashinfer.decode import PosEncodingMode


 class WrapperDispatch(Enum):
@@ -53,10 +55,19 @@ class PrefillMetadata:
     extend_no_prefix: bool


+# Reuse this workspace buffer across all flashinfer wrappers
+global_workspace_buffer = None
+
+
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""

-    def __init__(self, model_runner: ModelRunner):
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+    ):
         super().__init__()

         # Parse constants
@@ -69,6 +80,7 @@ class FlashInferAttnBackend(AttentionBackend):
             ),
         )
         self.max_context_len = model_runner.model_config.context_len
+        self.skip_prefill = skip_prefill

         assert not (
             model_runner.sliding_window_size is not None
@@ -90,16 +102,26 @@ class FlashInferAttnBackend(AttentionBackend):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

         # Allocate buffers
-        self.workspace_buffer = torch.empty(
-            global_config.flashinfer_workspace_size,
-            dtype=torch.uint8,
-            device=model_runner.device,
-        )
+        global global_workspace_buffer
+        if global_workspace_buffer is None:
+            global_workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        self.workspace_buffer = global_workspace_buffer
         max_bs = model_runner.req_to_token_pool.size
-        self.kv_indptr = [
-            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
-            for _ in range(self.num_wrappers)
-        ]
+        if kv_indptr_buf is None:
+            self.kv_indptr = [
+                torch.zeros(
+                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                )
+                for _ in range(self.num_wrappers)
+            ]
+        else:
+            assert self.num_wrappers == 1
+            self.kv_indptr = [kv_indptr_buf]
+
         self.kv_last_page_len = torch.ones(
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
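For context on the hunk above: FlashInfer wrappers need a large scratch workspace, and post3 switches from one allocation per backend instance to a single module-level tensor shared by every instance. A minimal standalone sketch of that caching pattern (function and variable names here are illustrative, not sglang APIs):

```python
import torch

# Module-level cache: allocated on first use, then shared by every caller.
_shared_workspace = None  # hypothetical name, for illustration only


def get_shared_workspace(size_bytes: int, device: str = "cuda") -> torch.Tensor:
    """Return one uint8 scratch buffer, allocating it only on the first call."""
    global _shared_workspace
    if _shared_workspace is None:
        _shared_workspace = torch.empty(size_bytes, dtype=torch.uint8, device=device)
    return _shared_workspace
```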
@@ -122,12 +144,17 @@ class FlashInferAttnBackend(AttentionBackend):
         self.prefill_wrappers_verify = []
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
-            self.prefill_wrappers_paged.append(
-                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-            )
-            self.prefill_wrappers_verify.append(
-                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-            )
+            if not skip_prefill:
+                self.prefill_wrappers_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        backend="fa2",
+                    )
+                )
+                self.prefill_wrappers_verify.append(
+                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+                )
             self.decode_wrappers.append(
                 BatchDecodeWithPagedKVCacheWrapper(
                     self.workspace_buffer,
@@ -137,10 +164,11 @@ class FlashInferAttnBackend(AttentionBackend):
             )

         # Create indices updater
+        if not skip_prefill:
+            self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
+                model_runner, self
+            )
         self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
-        self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
-            model_runner, self
-        )

         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
@@ -211,23 +239,30 @@ class FlashInferAttnBackend(AttentionBackend):
             self.prefill_wrappers_paged, use_ragged, extend_no_prefix
         )

-    def init_cuda_graph_state(self, max_bs: int):
-        cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.max_context_len,),
-            dtype=torch.int32,
-            device="cuda",
-        )
+    def init_cuda_graph_state(
+        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+    ):
+        if kv_indices_buf is None:
+            cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len,),
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = kv_indices_buf
+
         self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
             cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]

-        self.cuda_graph_custom_mask = torch.zeros(
-            (max_bs * self.max_context_len),
-            dtype=torch.uint8,
-            device="cuda",
-        )
-        self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
-        self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
+        if not self.skip_prefill:
+            self.cuda_graph_custom_mask = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
+            self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]

     def init_forward_metadata_capture_cuda_graph(
         self,
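The `kv_indptr_buf`/`kv_indices_buf` parameters added above exist so a caller (the multi-step draft backend further down) can hand each per-step backend a slice of one preallocated buffer. The motivation is the usual CUDA graph constraint: captured kernels replay against fixed addresses, so index tensors must be allocated once and only refreshed in place. A generic toy illustration of that constraint (plain PyTorch, requires a CUDA device; not sglang code):

```python
import torch

# Buffers that a CUDA graph reads/writes must keep the same storage forever,
# so allocate them once and only copy_ new values into them between replays.
static_in = torch.zeros(8, device="cuda")
static_out = torch.zeros(8, device="cuda")

# Warm up once on a side stream (recommended before capture), then capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)

# New inputs are written into the same tensor; the captured work is then replayed.
static_in.copy_(torch.arange(8, device="cuda", dtype=torch.float32))
g.replay()
print(static_out)  # tensor([ 0.,  2.,  4.,  6.,  8., 10., 12., 14.], device='cuda:0')
```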
@@ -279,7 +314,7 @@ class FlashInferAttnBackend(AttentionBackend):
                     paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
                     paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
                     custom_mask_buf=self.cuda_graph_custom_mask,
-                    qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
+                    mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
                 )
             )
             seq_lens_sum = seq_lens.sum().item()
@@ -602,11 +637,8 @@ class FlashInferIndicesUpdaterDecode:
                 self.req_to_token.shape[1],
             )
         else:
-
-
-                paged_kernel_lens,
-                self.req_to_token,
-            )
+            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+            bs = kv_indptr.shape[0] - 1

         wrapper.end_forward()
         wrapper.begin_forward(
@@ -800,7 +832,9 @@ class FlashInferIndicesUpdaterPrefill:
             kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+                paged_kernel_lens_sum + 256,
+                dtype=torch.int32,
+                device=req_pool_indices.device,
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
@@ -852,6 +886,132 @@ class FlashInferIndicesUpdaterPrefill:
         )


+class FlashInferMultiStepDraftBackend:
+    """
+    Wrap multiple flashinfer attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                FlashInferAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.kv_indptr_stride = self.kv_indptr.shape[1]
+
+    def common_template(self, forward_batch: ForwardBatch, call_fn: int):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            self.cuda_graph_kv_indices,
+            self.kv_indptr,
+            forward_batch.positions,
+            num_seqs,
+            self.topk,
+            self.pool_len,
+            self.kv_indptr_stride,
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+        )
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_bs * self.max_context_len),
+            dtype=torch.int32,
+            device="cuda",
+        )
+        self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+            decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
+                forward_batch.batch_size
+            ][0]
+            decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, call_fn)
+
+
 @triton.jit
 def create_flashinfer_kv_indices_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
@@ -935,3 +1095,88 @@ def should_use_tensor_core(
         return gqa_group_size > 4
     else:
         return False
+
+
+def fast_decode_plan(
+    self,
+    indptr: torch.Tensor,
+    indices: torch.Tensor,
+    last_page_len: torch.Tensor,
+    num_qo_heads: int,
+    num_kv_heads: int,
+    head_dim: int,
+    page_size: int,
+    pos_encoding_mode: str = "NONE",
+    window_left: int = -1,
+    logits_soft_cap: Optional[float] = None,
+    data_type: Union[str, torch.dtype] = "float16",
+    q_data_type: Optional[Union[str, torch.dtype]] = None,
+    sm_scale: Optional[float] = None,
+    rope_scale: Optional[float] = None,
+    rope_theta: Optional[float] = None,
+) -> None:
+    """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
+    batch_size = len(last_page_len)
+    if logits_soft_cap is None:
+        logits_soft_cap = 0.0
+    if self.is_cuda_graph_enabled:
+        if batch_size != self._fixed_batch_size:
+            raise ValueError(
+                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
+                " mismatches the batch size set during initialization {}".format(
+                    batch_size, self._fixed_batch_size
+                )
+            )
+        if len(indices) > len(self._paged_kv_indices_buf):
+            raise ValueError(
+                "The size of indices should be less than or equal to the allocated buffer"
+            )
+    else:
+        self._paged_kv_indptr_buf = indptr
+        self._paged_kv_indices_buf = indices
+        self._paged_kv_last_page_len_buf = last_page_len
+    # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
+    if not q_data_type:
+        q_data_type = data_type
+    if not hasattr(self, "empty_q_data"):
+        self.empty_q_data = torch.empty(
+            0,
+            dtype=(
+                getattr(torch, q_data_type)
+                if isinstance(q_data_type, str)
+                else q_data_type
+            ),
+        )
+        self.empty_kv_cache = torch.empty(
+            0,
+            dtype=(
+                getattr(torch, data_type) if isinstance(data_type, str) else data_type
+            ),
+        )
+        self.last_page_len = torch.ones(32768, dtype=torch.int32)
+    empty_q_data = self.empty_q_data
+    empty_kv_cache = self.empty_kv_cache
+    stream = torch.cuda.current_stream()
+    self._cached_module.plan(
+        self._float_workspace_buffer,
+        self._int_workspace_buffer,
+        self._pin_memory_int_workspace_buffer,
+        indptr.to("cpu"),
+        batch_size,
+        num_qo_heads,
+        num_kv_heads,
+        page_size,
+        self.is_cuda_graph_enabled,
+        window_left,
+        logits_soft_cap,
+        head_dim,
+        empty_q_data,
+        empty_kv_cache,
+        stream.cuda_stream,
+    )
+    self._pos_encoding_mode = pos_encoding_mode
+    self._window_left = window_left
+    self._logits_soft_cap = logits_soft_cap
+    self._sm_scale = sm_scale
+    self._rope_scale = rope_scale
+    self._rope_theta = rope_theta
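`fast_decode_plan` above is not called directly; the multi-step draft backend binds it onto one specific wrapper instance with `functools.partial` (see `decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)` in the earlier hunk), so only that instance takes the fast path. A small generic sketch of that per-instance override pattern (all names hypothetical):

```python
from functools import partial


class Planner:
    def begin_forward(self, n: int) -> str:
        return f"full plan for {n}"


def fast_plan(self: Planner, n: int) -> str:
    # Drop-in replacement; it receives the instance explicitly as `self`.
    return f"fast plan for {n}"


w = Planner()
w.begin_forward = partial(fast_plan, w)  # override on this instance only
print(w.begin_forward(4))                # -> fast plan for 4
print(Planner().begin_forward(4))        # -> full plan for 4 (class unchanged)
```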
sglang/srt/layers/attention/triton_backend.py

@@ -5,6 +5,9 @@ from typing import TYPE_CHECKING, Optional
 import torch

 from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_backend import (
+    create_flashinfer_kv_indices_triton,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

@@ -29,6 +32,15 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd

+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.qo_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )
+
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
@@ -45,6 +57,9 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""

+        bs = forward_batch.batch_size
+        kv_indptr = self.kv_indptr
+
         if forward_batch.forward_mode.is_decode():
             attn_logits = torch.empty(
                 (
@@ -58,11 +73,63 @@ class TritonAttnBackend(AttentionBackend):
             )

             max_extend_len = None
+
+            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            qo_indptr = None
+            custom_mask = None
+            mask_offsets = None
         else:
+            kv_indptr[1 : bs + 1] = torch.cumsum(
+                forward_batch.extend_prefix_lens, dim=0
+            )
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                forward_batch.extend_prefix_lens.sum().item(),
+                dtype=torch.int32,
+                device=self.device,
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.extend_prefix_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            qo_indptr = self.qo_indptr
+            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+            mask_offsets = None
+
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

-        self.forward_metadata = attn_logits, max_extend_len
+        self.forward_metadata = (
+            attn_logits,
+            max_extend_len,
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            custom_mask,
+            mask_offsets,
+        )

     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
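The metadata built above is a CSR-style view of the paged KV cache: `kv_indptr[i]` marks where request i's entries begin, and `kv_indices` lists the token-pool slots themselves, gathered from `req_to_token` by a small Triton kernel. A self-contained pure-PyTorch sketch of the same layout (toy tensors, no Triton, names illustrative):

```python
import torch

# Toy stand-ins for the request-to-token pool structures used by the backend.
req_to_token = torch.tensor([[10, 11, 12, 0], [20, 21, 0, 0]], dtype=torch.int32)
req_pool_indices = torch.tensor([0, 1])
seq_lens = torch.tensor([3, 2], dtype=torch.int32)

bs = seq_lens.numel()
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)          # -> [0, 3, 5]

kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
for i in range(bs):                                     # the Triton kernel does this per request
    row = req_to_token[req_pool_indices[i], : seq_lens[i]]
    kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = row

print(kv_indptr.tolist())    # [0, 3, 5]
print(kv_indices.tolist())   # [10, 11, 12, 20, 21]
```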
@@ -73,7 +140,12 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_attn_logits = torch.empty(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
-            device="cuda",
+            device=self.device,
+        )
+        self.cuda_graph_kv_indices = torch.zeros(
+            (max_bs * self.cuda_graph_max_seq_len),
+            dtype=torch.int32,
+            device=self.device,
         )

     def init_forward_metadata_capture_cuda_graph(
@@ -90,9 +162,28 @@ class TritonAttnBackend(AttentionBackend):
         assert forward_mode.is_decode(), "Not supported"
         assert spec_info is None, "Not supported"

+        kv_indptr = self.kv_indptr
+        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indices = self.cuda_graph_kv_indices
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            seq_lens,
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token.stride(0),
+        )
+
         self.forward_metadata = (
             self.cuda_graph_attn_logits,
             None,
+            kv_indptr,
+            kv_indices,
+            None,
+            None,
+            None,
         )

     def init_forward_metadata_replay_cuda_graph(
@@ -109,6 +200,20 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

+        kv_indptr = self.kv_indptr
+        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indices = self.cuda_graph_kv_indices
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices[:bs],
+            seq_lens[:bs],
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token.stride(0),
+        )
+
     def get_cuda_graph_seq_len_fill_value(self):
         return 1

@@ -132,7 +237,15 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )

-        _, max_extend_len = self.forward_metadata
+        (
+            _,
+            max_extend_len,
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            custom_mask,
+            mask_offsets,
+        ) = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -140,11 +253,11 @@ class TritonAttnBackend(AttentionBackend):
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
-            forward_batch.req_to_token_pool.req_to_token,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
-            forward_batch.extend_seq_lens,
-            forward_batch.extend_start_loc,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            custom_mask,
+            mask_offsets,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
@@ -170,7 +283,7 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

-        attn_logits, _ = self.forward_metadata
+        attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata

         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(

@@ -182,9 +295,8 @@ class TritonAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            forward_batch.req_to_token_pool.req_to_token,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             self.num_kv_splits,
             layer.scaling,
|