sglang 0.4.3.post3__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +94 -48
- sglang/srt/layers/attention/triton_backend.py +4 -2
- sglang/srt/managers/io_struct.py +1 -0
- sglang/srt/managers/scheduler.py +144 -127
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +34 -29
- sglang/srt/metrics/collector.py +8 -0
- sglang/srt/model_executor/cuda_graph_runner.py +1 -7
- sglang/srt/model_executor/model_runner.py +97 -78
- sglang/srt/server_args.py +3 -12
- sglang/srt/speculative/build_eagle_tree.py +6 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -11
- sglang/srt/speculative/eagle_utils.py +2 -1
- sglang/srt/speculative/eagle_worker.py +67 -32
- sglang/version.py +1 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +2 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +21 -21
- {sglang-0.4.3.post3.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post3.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.3.post3.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py
CHANGED
@@ -220,7 +220,7 @@ async def async_request_openai_completions(
 
                         most_recent_timestamp = timestamp
                         generated_text += data["choices"][0]["text"]
-                        output_len = data.get("usage", {}).get(
+                        output_len = (data.get("usage") or {}).get(
                             "completion_tokens", output_len
                         )
 
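This one-line change guards against streaming responses that send "usage": null instead of omitting the key. A minimal sketch of the difference, using a hypothetical payload:

    # Hypothetical streaming chunk where the server sends usage explicitly as null.
    data = {"choices": [{"text": "hi"}], "usage": None}
    output_len = 7  # value carried over from a previous chunk

    # Old form: data.get("usage", {}) returns None because the key exists, so the
    # chained .get(...) raises AttributeError: 'NoneType' object has no attribute 'get'.
    try:
        output_len = data.get("usage", {}).get("completion_tokens", output_len)
    except AttributeError:
        pass

    # New form: "or {}" normalizes both a missing key and an explicit null,
    # keeping the previously accumulated value as the fallback.
    output_len = (data.get("usage") or {}).get("completion_tokens", output_len)
    assert output_len == 7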
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -7,16 +7,14 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
-import math
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
 import torch
 import triton
-import triton.language as tl
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -37,7 +35,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
    from flashinfer.cascade import merge_state
-    from flashinfer.decode import
+    from flashinfer.decode import _get_range_buf, get_seq_lens
 
 
 class WrapperDispatch(Enum):
@@ -73,8 +71,6 @@ class FlashInferAttnBackend(AttentionBackend):
     ):
         super().__init__()
 
-        self.is_multimodal = model_runner.model_config.is_multimodal
-
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -86,6 +82,7 @@ class FlashInferAttnBackend(AttentionBackend):
         )
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
+        self.is_multimodal = model_runner.model_config.is_multimodal
 
         assert not (
             model_runner.sliding_window_size is not None
@@ -115,7 +112,6 @@ class FlashInferAttnBackend(AttentionBackend):
             device=model_runner.device,
         )
         self.workspace_buffer = global_workspace_buffer
-
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [
@@ -163,9 +159,11 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
             )
             self.prefill_wrappers_verify.append(
-                BatchPrefillWithPagedKVCacheWrapper(
+                BatchPrefillWithPagedKVCacheWrapper(
+                    self.workspace_buffer,
+                    "NHD",
+                )
             )
-
             self.decode_wrappers.append(
                 BatchDecodeWithPagedKVCacheWrapper(
                     self.workspace_buffer,
@@ -178,13 +176,14 @@ class FlashInferAttnBackend(AttentionBackend):
         if not skip_prefill:
             self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
                 model_runner, self
-            )
+            )  # for verify
         self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
 
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
         self.decode_cuda_graph_metadata = {}
-        self.prefill_cuda_graph_metadata = {}
+        self.prefill_cuda_graph_metadata = {}  # For verify
+        self.draft_extend_cuda_graph_metadata = {}  # For draft extend
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
@@ -300,7 +299,6 @@ class FlashInferAttnBackend(AttentionBackend):
                     ],
                 )
             )
-
            seq_lens_sum = seq_lens.sum().item()
            self.indices_updater_decode.update(
                req_pool_indices,
@@ -312,6 +310,10 @@ class FlashInferAttnBackend(AttentionBackend):
            )
            self.decode_cuda_graph_metadata[bs] = decode_wrappers
            self.forward_metadata = DecodeMetadata(decode_wrappers)
+            for i in range(self.num_wrappers):
+                decode_wrappers[i].begin_forward = partial(
+                    fast_decode_plan, decode_wrappers[i]
+                )
        elif forward_mode.is_target_verify():
            prefill_wrappers = []
            for i in range(self.num_wrappers):
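The loop added in the hunk above rebinds each decode wrapper's begin_forward to a partial of fast_decode_plan, so later calls on the wrapper go through the fast planner. A minimal, self-contained sketch of that rebinding pattern (Wrapper and fast_plan are illustrative stand-ins, not flashinfer classes):

    from functools import partial

    class Wrapper:  # stand-in for a decode wrapper object
        def begin_forward(self, batch_size):
            return f"slow plan for {batch_size}"

    def fast_plan(wrapper, batch_size):
        # Free function that takes the wrapper as its first argument,
        # mirroring the shape of fast_decode_plan(self, ...).
        return f"fast plan for {batch_size}"

    w = Wrapper()
    # Rebind the instance attribute: partial() supplies the "self" argument, so
    # existing call sites w.begin_forward(bs) transparently use the fast path.
    w.begin_forward = partial(fast_plan, w)
    assert w.begin_forward(8) == "fast plan for 8"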
@@ -437,7 +439,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=False,
                 sm_scale=layer.scaling,
-                logits_soft_cap=
+                logits_soft_cap=logits_soft_cap,
             )
 
             o, _ = merge_state(o1, s1, o2, s2)
@@ -636,9 +638,15 @@ class FlashInferIndicesUpdaterDecode:
             bs = len(req_pool_indices)
             kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
-            )
+
+            if wrapper.is_cuda_graph_enabled:
+                # Directly write to the cuda graph input buffer
+                kv_indices = wrapper._paged_kv_indices_buf
+            else:
+                kv_indices = torch.empty(
+                    paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+                )
+
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -649,9 +657,9 @@ class FlashInferIndicesUpdaterDecode:
                 self.req_to_token.shape[1],
             )
         else:
-            assert isinstance(spec_info, EagleDraftInput)
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
+
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
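When the wrapper is driven by a CUDA graph, new KV indices must be written into the buffer the graph captured rather than bound to a freshly allocated tensor, because graph replay only reads the original storage. A small sketch of that buffer-reuse pattern, with hypothetical names and sizes (the CUDA graph capture itself is omitted):

    import torch

    MAX_KV_INDICES = 4096
    # Buffer whose address a captured CUDA graph (or any downstream consumer) holds on to.
    paged_kv_indices_buf = torch.zeros(MAX_KV_INDICES, dtype=torch.int32)

    def update_kv_indices(new_indices, graph_enabled):
        if graph_enabled:
            # Write in place: the consumer keeps reading the same storage.
            paged_kv_indices_buf[: new_indices.numel()].copy_(new_indices)
            return paged_kv_indices_buf
        # Eager path: a fresh tensor per step is fine because it is passed explicitly.
        return new_indices.to(torch.int32).clone()

    out = update_kv_indices(torch.tensor([3, 1, 2], dtype=torch.int32), graph_enabled=True)
    assert out.data_ptr() == paged_kv_indices_buf.data_ptr()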
@@ -699,7 +707,7 @@ class FlashInferIndicesUpdaterPrefill:
 
     def update(
         self,
-        req_pool_indices: torch.
+        req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
@@ -713,7 +721,7 @@ class FlashInferIndicesUpdaterPrefill:
 
     def update_single_wrapper(
         self,
-        req_pool_indices: torch.
+        req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
@@ -858,7 +866,6 @@ class FlashInferIndicesUpdaterPrefill:
             kv_indices,
             self.req_to_token.shape[1],
         )
-
         qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
         qo_indptr = qo_indptr[: bs + 1]
         custom_mask = None
@@ -897,6 +904,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.head_dim,
             1,
             q_data_type=self.q_data_type,
+            kv_data_type=self.data_type,
             custom_mask=custom_mask,
             non_blocking=True,
         )
@@ -954,7 +962,10 @@ class FlashInferMultiStepDraftBackend:
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
 
     def common_template(
-        self,
+        self,
+        forward_batch: ForwardBatch,
+        kv_indices_buffer: torch.Tensor,
+        call_fn: Callable,
     ):
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
@@ -1042,17 +1053,15 @@ class FlashInferMultiStepDraftBackend:
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
             )
-            decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
-                forward_batch.batch_size
-            ][0]
-            decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
-    def init_forward_metadata_replay_cuda_graph(
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
-
+                bs,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
@@ -1113,6 +1122,11 @@ def should_use_tensor_core(
     return False
 
 
+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global_override_indptr_cpu = None
+
+
 def fast_decode_plan(
     self,
     indptr: torch.Tensor,
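The new module-level global_override_indptr_cpu lets a caller that already maintains a host-side copy of the indptr publish it, so fast_decode_plan can skip the blocking indptr.cpu() transfer each step. A standalone sketch of this pattern, not the sglang module itself:

    import torch

    # Module-level escape hatch, mirroring global_override_indptr_cpu in the diff.
    override_indptr_cpu = None

    def plan(indptr: torch.Tensor) -> torch.Tensor:
        # Prefer the host copy maintained by the caller; fall back to a blocking
        # device-to-host copy only when no override is installed. In real use the
        # indptr argument would live on the GPU.
        return override_indptr_cpu if override_indptr_cpu is not None else indptr.cpu()

    # Caller keeps the indptr on both devices and updates the CPU copy itself.
    indptr_cpu = torch.tensor([0, 4, 9], dtype=torch.int32)
    override_indptr_cpu = indptr_cpu
    assert plan(indptr_cpu).data_ptr() == indptr_cpu.data_ptr()  # no copy happened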
@@ -1142,6 +1156,9 @@ def fast_decode_plan(
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
 
+    if self.use_tensor_cores:
+        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
+
     if self.is_cuda_graph_enabled:
         if batch_size != self._fixed_batch_size:
             raise ValueError(
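The tensor-core path needs a query indptr on the host; _get_range_buf(batch_size + 1, "cpu") presumably returns a cached 0..batch_size range, which is exactly the query indptr for decode, where every request contributes one query token. Illustratively:

    import torch

    batch_size = 4
    # For decode, each request has exactly one new query token, so the query indptr
    # is simply 0, 1, 2, ..., batch_size (what a cached range buffer would hold).
    qo_indptr_host = torch.arange(batch_size + 1, dtype=torch.int32, device="cpu")
    assert qo_indptr_host.tolist() == [0, 1, 2, 3, 4]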
@@ -1154,7 +1171,7 @@ def fast_decode_plan(
             raise ValueError(
                 "The size of indices should be less than or equal to the allocated buffer"
             )
-        # Skip these copies
+        # Skip these copies because we directly write to them during preparation
         # self._paged_kv_indptr_buf.copy_(indptr)
         # self._paged_kv_indices_buf[: len(indices)] = indices
         # self._paged_kv_last_page_len_buf.copy_(last_page_len)
@@ -1162,6 +1179,7 @@ def fast_decode_plan(
         self._paged_kv_indptr_buf = indptr
         self._paged_kv_indices_buf = indices
         self._paged_kv_last_page_len_buf = last_page_len
+        self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
 
     # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
     if not q_data_type:
@@ -1184,27 +1202,55 @@ def fast_decode_plan(
         )
         self.last_page_len = torch.ones(32768, dtype=torch.int32)
 
-
-
-
-
-        self._float_workspace_buffer,
-        self._int_workspace_buffer,
-        self._pin_memory_int_workspace_buffer,
-        indptr.to("cpu"),
-        batch_size,
-        num_qo_heads,
-        num_kv_heads,
-        page_size,
-        self.is_cuda_graph_enabled,
-        window_left,
-        logits_soft_cap,
-        head_dim,
-        head_dim,
-        empty_q_data,
-        empty_kv_cache,
-        stream.cuda_stream,
+    indptr_host = (
+        global_override_indptr_cpu
+        if global_override_indptr_cpu is not None
+        else indptr.cpu()
     )
+
+    if self.use_tensor_cores:
+        kv_lens_arr_host = get_seq_lens(
+            indptr_host, self.last_page_len[:batch_size], page_size
+        )
+
+        self._plan_info = self._cached_module.plan(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_host,
+            indptr_host,
+            kv_lens_arr_host,
+            batch_size,  # total_num_rows
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            page_size,
+            self.is_cuda_graph_enabled,
+            head_dim,
+            head_dim,
+            False,  # causal
+            torch.cuda.current_stream().cuda_stream,
+        )
+    else:
+        self._plan_info = self._cached_module.plan(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            indptr_host,
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            page_size,
+            self.is_cuda_graph_enabled,
+            window_left,
+            logits_soft_cap,
+            head_dim,
+            head_dim,
+            self.empty_q_data,
+            self.empty_kv_cache,
+            torch.cuda.current_stream().cuda_stream,
+        )
+
     self._pos_encoding_mode = pos_encoding_mode
     self._window_left = window_left
     self._logits_soft_cap = logits_soft_cap
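The tensor-core branch also computes per-request KV lengths on the host from the paged indptr and the last-page lengths. The helper below mirrors the usual paged-KV arithmetic and is only an illustration, not flashinfer's get_seq_lens implementation:

    import torch

    def paged_kv_seq_lens(kv_indptr, last_page_len, page_size):
        # Standard paged-KV bookkeeping: every page except the last is full,
        # and the last page holds last_page_len entries.
        num_pages = kv_indptr[1:] - kv_indptr[:-1]
        return (num_pages - 1).clamp(min=0) * page_size + last_page_len

    kv_indptr = torch.tensor([0, 3, 7], dtype=torch.int32)    # request 0 has 3 pages, request 1 has 4
    last_page_len = torch.tensor([2, 4], dtype=torch.int32)   # entries used in each request's last page
    assert paged_kv_seq_lens(kv_indptr, last_page_len, page_size=4).tolist() == [10, 16]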
sglang/srt/layers/attention/triton_backend.py
CHANGED
@@ -578,10 +578,12 @@ class TritonMultiStepDraftBackend:
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
-    def init_forward_metadata_replay_cuda_graph(
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
-
+                bs,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,