sglang 0.1.18__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/api.py +26 -0
- sglang/backend/runtime_endpoint.py +18 -14
- sglang/bench_latency.py +40 -18
- sglang/global_config.py +21 -16
- sglang/lang/chat_template.py +41 -6
- sglang/lang/interpreter.py +5 -1
- sglang/lang/ir.py +61 -25
- sglang/srt/constrained/__init__.py +3 -2
- sglang/srt/hf_transformers_utils.py +7 -3
- sglang/srt/layers/extend_attention.py +2 -1
- sglang/srt/layers/fused_moe.py +181 -167
- sglang/srt/layers/logits_processor.py +55 -19
- sglang/srt/layers/radix_attention.py +33 -59
- sglang/srt/layers/token_attention.py +4 -8
- sglang/srt/managers/controller/cuda_graph_runner.py +172 -0
- sglang/srt/managers/controller/infer_batch.py +244 -36
- sglang/srt/managers/controller/manager_single.py +1 -1
- sglang/srt/managers/controller/model_runner.py +69 -284
- sglang/srt/managers/controller/tp_worker.py +39 -20
- sglang/srt/managers/detokenizer_manager.py +4 -2
- sglang/srt/managers/io_struct.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +14 -13
- sglang/srt/memory_pool.py +33 -6
- sglang/srt/model_config.py +6 -0
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/llama2.py +3 -3
- sglang/srt/models/llama_classification.py +10 -7
- sglang/srt/models/minicpm.py +373 -0
- sglang/srt/models/qwen2_moe.py +454 -0
- sglang/srt/openai_api_adapter.py +2 -2
- sglang/srt/openai_protocol.py +1 -1
- sglang/srt/server.py +18 -8
- sglang/srt/server_args.py +24 -20
- sglang/srt/utils.py +68 -35
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/METADATA +19 -13
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/RECORD +40 -36
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/WHEEL +1 -1
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/LICENSE +0 -0
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/infer_batch.py

@@ -3,7 +3,7 @@
 import warnings
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import List
+from typing import List, Union
 
 import numpy as np
 import torch
@@ -15,10 +15,16 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
+# Store some global server args
+global_server_args_dict = {}
+
 
 class ForwardMode(IntEnum):
+    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
     PREFILL = auto()
+    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
     EXTEND = auto()
+    # Decode one token.
     DECODE = auto()
 
 
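The comments added to `ForwardMode` draw the line between extend (compute KV cache for the uncached part of a prompt) and decode (one token per running request). A minimal, hypothetical sketch of that distinction; `pick_forward_mode` and its argument are not part of sglang:

```python
from enum import IntEnum, auto

class ForwardMode(IntEnum):
    PREFILL = auto()  # deprecated; EXTEND covers this case
    EXTEND = auto()   # compute KV cache for the uncached suffix of each prompt
    DECODE = auto()   # compute exactly one new token per running request

def pick_forward_mode(tokens_to_compute_per_req: list[int]) -> ForwardMode:
    # Hypothetical helper: a decode step computes exactly one token per request;
    # anything longer means we are extending prompts on top of cached prefixes.
    if all(n == 1 for n in tokens_to_compute_per_req):
        return ForwardMode.DECODE
    return ForwardMode.EXTEND

print(pick_forward_mode([1, 1, 1]).name)  # DECODE
print(pick_forward_mode([5, 1]).name)     # EXTEND
```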
@@ -31,7 +37,7 @@ class BaseFinishReason:
 
 
 class FINISH_MATCHED_TOKEN(BaseFinishReason):
-    def __init__(self, matched: int):
+    def __init__(self, matched: Union[int, List[int]]):
         super().__init__()
         self.matched = matched
 
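`FINISH_MATCHED_TOKEN.matched` is widened from `int` to `Union[int, List[int]]`, so a finish reason can record either a single matched stop token or several. A hedged sketch of how such a value could be produced; `check_stop` and the trimmed-down class mirrors below are illustrative, not the sglang implementation:

```python
from typing import List, Optional, Union

class BaseFinishReason:
    pass

class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

def check_stop(last_token_id: int, stop_token_ids: List[int]) -> Optional[FINISH_MATCHED_TOKEN]:
    # Illustrative only: record which configured stop token ended the request.
    if last_token_id in stop_token_ids:
        return FINISH_MATCHED_TOKEN(matched=last_token_id)
    return None

reason = check_stop(last_token_id=2, stop_token_ids=[2, 32000])
print(reason.matched if reason else None)  # 2
```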
@@ -66,7 +72,10 @@ class FINISH_ABORT(BaseFinishReason):
 
 
 class Req:
+    """Store all inforamtion of a request."""
+
     def __init__(self, rid, origin_input_text, origin_input_ids):
+        # Input and output info
         self.rid = rid
         self.origin_input_text = origin_input_text
         self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
@@ -74,7 +83,7 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.input_ids = None  # input_ids = origin_input_ids + output_ids
 
-        # For incremental
+        # For incremental decoding
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
@@ -89,20 +98,19 @@ class Req:
         self.image_offset = 0
         self.pad_value = None
 
+        # Prefix info
+        self.extend_input_len = 0
+        self.prefix_indices = []
+        self.last_node = None
+
         # Sampling parameters
         self.sampling_params = None
         self.stream = False
 
-        self.tokenizer = None
-
         # Check finish
+        self.tokenizer = None
         self.finished_reason = None
 
-        # Prefix info
-        self.extend_input_len = 0
-        self.prefix_indices = []
-        self.last_node = None
-
         # Logprobs
         self.return_logprob = False
         self.logprob_start_len = 0
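The prefix bookkeeping (`extend_input_len`, `prefix_indices`, `last_node`) now sits with the rest of the per-request input state. Assuming `prefix_indices` holds the KV-cache slots reused from the radix cache, the fields relate as in this standalone sketch (not sglang code):

```python
# input_ids        : full token ids fed to the model for this request
# prefix_indices   : KV-cache locations already populated for a shared prefix
# extend_input_len : tokens whose KV entries still have to be computed
def extend_len(input_ids: list[int], prefix_indices: list[int]) -> int:
    assert len(prefix_indices) <= len(input_ids)
    return len(input_ids) - len(prefix_indices)

print(extend_len(input_ids=[1, 2, 3, 4, 5], prefix_indices=[10, 11]))  # 3
```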
@@ -252,35 +260,36 @@ class Req:
 
 @dataclass
 class Batch:
+    """Store all inforamtion of a batch."""
+
+    # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool
     tree_cache: RadixCache
 
-    #
+    # Batched arguments to model runner
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
 
-    #
+    # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    #
+    # For multimodal
     pixel_values: List[torch.Tensor] = None
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None
 
-    #
+    # Other arguments for control
     output_ids: torch.Tensor = None
     extend_num_tokens: int = None
 
-    #
+    # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None
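The relabeled "Batched sampling params" block gathers per-request sampling settings into one tensor per parameter. A rough sketch of that batching step, assuming each request carries a `sampling_params` object with `temperature`, `top_p`, and `top_k` attributes (the helper name is hypothetical and the exact layout is not taken from sglang):

```python
import torch

def batch_sampling_params(reqs, device="cpu"):
    # Stack per-request scalars into one tensor per parameter so the sampler
    # can operate on the whole batch at once.
    temperatures = torch.tensor(
        [r.sampling_params.temperature for r in reqs], dtype=torch.float, device=device
    ).view(-1, 1)
    top_ps = torch.tensor(
        [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
    )
    top_ks = torch.tensor(
        [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
    )
    return temperatures, top_ps, top_ks
```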
@@ -303,8 +312,8 @@ class Batch:
     def is_empty(self):
         return len(self.reqs) == 0
 
-    # whether batch has at least 1 streaming request
     def has_stream(self) -> bool:
+        # Return whether batch has at least 1 streaming request
         return any(r.stream for r in self.reqs)
 
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
@@ -338,7 +347,7 @@ class Batch:
 
         position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
 
-        #
+        # Allocate memory
         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
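The newly labeled allocation step reserves KV-cache slots only for tokens that are not covered by a cached prefix. A small standalone example of the arithmetic, using plain numpy as the diff does:

```python
import numpy as np

# Two requests: lengths 7 and 5, with cached prefixes of 3 and 0 tokens.
seq_lens = np.array([7, 5])
prefix_lens = np.array([3, 0])

# Only the uncached suffixes need new KV-cache slots.
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
print(extend_num_tokens)  # 9 -> out_cache_loc would hold 9 slot indices
```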
@@ -555,21 +564,12 @@ class Batch:
 
         # Alloc mem
         bs = len(self.reqs)
-
-        if alloc_res is None:
-            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
+        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
 
-            if self.out_cache_loc is None:
-                print("Decode out of memory. This should never happen.")
-                self.tree_cache.pretty_print()
-                exit()
-
-            self.out_cache_cont_start = None
-            self.out_cache_cont_end = None
-        else:
-            self.out_cache_loc = alloc_res[0]
-            self.out_cache_cont_start = alloc_res[1]
-            self.out_cache_cont_end = alloc_res[2]
+        if self.out_cache_loc is None:
+            print("Decode out of memory. This should never happen.")
+            self.tree_cache.pretty_print()
+            exit()
 
         self.req_to_token_pool.req_to_token[
             self.req_pool_indices, self.seq_lens - 1
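With the contiguous-allocation path (`out_cache_cont_start`/`out_cache_cont_end`) removed, decode now simply allocates one slot per request and treats allocation failure as fatal. A hedged sketch of that flow, assuming a pool whose `alloc` returns `None` when it cannot satisfy the request; the diff prints a message and exits, while the sketch raises for brevity:

```python
def alloc_decode_slots(token_to_kv_pool, batch_size: int):
    # One new KV-cache slot per running request in a decode step.
    out_cache_loc = token_to_kv_pool.alloc(batch_size)
    if out_cache_loc is None:
        # The scheduler's admission control is supposed to prevent this,
        # so treat it as a fatal condition rather than retrying.
        raise RuntimeError("Decode out of memory. This should never happen.")
    return out_cache_loc
```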
@@ -583,7 +583,7 @@ class Batch:
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
@@ -611,7 +611,7 @@ class Batch:
         self.position_ids_offsets = torch.concat(
             [self.position_ids_offsets, other.position_ids_offsets]
         )
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.out_cache_loc = None
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
@@ -664,7 +664,11 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
         probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        try:
+            sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        except RuntimeError as e:
+            warnings.warn(f"Ignore errors in sampling: {e}")
+            sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device)
         batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
             -1
         )
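The new `try`/`except` keeps a rare `torch.multinomial` failure (for example, a row whose filtered probabilities are all zero or non-finite) from crashing the whole batch. A self-contained sketch of the same pattern; the top-p/top-k filtering below is a simplified stand-in for `_top_p_top_k` in this file, not a copy of it:

```python
import warnings
import torch

def sample_top_p_top_k(logits: torch.Tensor, top_p: float, top_k: int) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    # Zero out tokens outside the nucleus (top-p) and beyond the top-k cutoff.
    cumsum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(cumsum - probs_sort) > top_p] = 0.0
    probs_sort[:, top_k:] = 0.0
    try:
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError as e:
        # Same spirit as the diff: never crash the batch on a sampling error.
        warnings.warn(f"Ignore errors in sampling: {e}")
        sampled_index = torch.zeros(
            (probs_sort.shape[0], 1), dtype=torch.int64, device=probs.device
        )
    return torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)

# Example: sample one token id for each of 2 sequences over a 10-token vocab.
token_ids = sample_top_p_top_k(torch.randn(2, 10), top_p=0.9, top_k=5)
print(token_ids.shape)  # torch.Size([2])
```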
@@ -692,3 +696,207 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor
     ] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     return probs_sort, probs_idx
+
+
+@dataclass
+class InputMetadata:
+    """Store all inforamtion of a forward pass."""
+
+    forward_mode: ForwardMode
+    batch_size: int
+    total_num_tokens: int
+    req_pool_indices: torch.Tensor
+    seq_lens: torch.Tensor
+    positions: torch.Tensor
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool: TokenToKVPool
+
+    # For extend
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+    extend_no_prefix: bool
+
+    # Output location of the KV cache
+    out_cache_loc: torch.Tensor = None
+
+    # Output options
+    return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
+
+    # Trition attention backend
+    triton_max_seq_len: int = 0
+    triton_max_extend_len: int = 0
+    triton_start_loc: torch.Tensor = None
+    triton_prefix_lens: torch.Tensor = None
+
+    # FlashInfer attention backend
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+
+    @classmethod
+    def create(
+        cls,
+        model_runner,
+        forward_mode,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        position_ids_offsets,
+        out_cache_loc,
+        top_logprobs_nums=None,
+        return_logprob=False,
+        skip_flashinfer_init=False,
+    ):
+        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
+                                 model_runner.flashinfer_decode_wrapper)
+
+        batch_size = len(req_pool_indices)
+
+        if forward_mode == ForwardMode.DECODE:
+            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+            extend_seq_lens = extend_start_loc = extend_no_prefix = None
+            if not model_runner.server_args.disable_flashinfer:
+                # This variable is not needed in this case,
+                # we do not compute it to make it compatbile with cuda graph.
+                total_num_tokens = None
+            else:
+                total_num_tokens = int(torch.sum(seq_lens))
+        else:
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+            positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
+                        )
+                        for i in range(batch_size)
+                    ],
+                    axis=0,
+                ),
+                device="cuda",
+            )
+            extend_seq_lens = seq_lens - prefix_lens
+            extend_start_loc = torch.zeros_like(seq_lens)
+            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+            extend_no_prefix = torch.all(prefix_lens == 0)
+            total_num_tokens = int(torch.sum(seq_lens))
+
+        ret = cls(
+            forward_mode=forward_mode,
+            batch_size=batch_size,
+            total_num_tokens=total_num_tokens,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            positions=positions,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            out_cache_loc=out_cache_loc,
+            extend_seq_lens=extend_seq_lens,
+            extend_start_loc=extend_start_loc,
+            extend_no_prefix=extend_no_prefix,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+        )
+
+        if model_runner.server_args.disable_flashinfer:
+            (ret.triton_max_seq_len,
+             ret.triton_max_extend_len,
+             ret.triton_start_loc,
+             ret.triton_prefix_lens) = init_triton_args(forward_mode, seq_lens, prefix_lens)
+
+        return ret
+
+
+def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
+                         flashinfer_decode_wrapper):
+    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
+    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
+    head_dim = model_runner.model_config.head_dim
+    batch_size = len(req_pool_indices)
+
+    if forward_mode == ForwardMode.DECODE:
+        paged_kernel_lens = seq_lens
+    else:
+        paged_kernel_lens = prefix_lens
+
+    kv_indptr = torch.zeros(
+        (batch_size + 1,), dtype=torch.int32, device="cuda"
+    )
+    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+    kv_indices = torch.cat(
+        [
+            model_runner.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+            ]
+            for i in range(batch_size)
+        ],
+        dim=0,
+    ).contiguous()
+    kv_last_page_len = torch.ones(
+        (batch_size,), dtype=torch.int32, device="cuda"
+    )
+
+    if forward_mode == ForwardMode.DECODE:
+        flashinfer_decode_wrapper.end_forward()
+        flashinfer_decode_wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+    else:
+        # extend part
+        qo_indptr = torch.zeros(
+            (batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+            qo_indptr,
+            qo_indptr,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+        )
+
+        # cached part
+        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+
+
+def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    batch_size = len(seq_lens)
+    max_seq_len = int(torch.max(seq_lens))
+    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+
+    if forward_mode == ForwardMode.DECODE:
+        max_extend_len = None
+    else:
+        extend_seq_lens = seq_lens - prefix_lens
+        max_extend_len = int(torch.max(extend_seq_lens))
+
+    return max_seq_len, max_extend_len, start_loc, prefix_lens
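To make the shapes in `InputMetadata.create` and `init_triton_args` concrete: decode contributes one position per sequence, while extend contributes one position per uncached token, with `extend_start_loc` marking where each sequence starts in the flattened batch. A standalone CPU illustration with hand-picked lengths and no position offsets (not taken from the package, just the same arithmetic):

```python
import numpy as np
import torch

seq_lens = torch.tensor([7, 5])
prefix_lens = torch.tensor([3, 0])

# Decode: one position per sequence, pointing at the token being generated.
decode_positions = (seq_lens - 1).to(torch.int64)
print(decode_positions.tolist())  # [6, 4]

# Extend: positions for every token whose KV entry still has to be computed.
extend_positions = torch.tensor(
    np.concatenate(
        [np.arange(int(p), int(s)) for p, s in zip(prefix_lens, seq_lens)], axis=0
    )
)
print(extend_positions.tolist())  # [3, 4, 5, 6, 0, 1, 2, 3, 4]

# Per-sequence start offsets into that flattened extend buffer
# (same cumsum trick as extend_start_loc / init_triton_args above).
extend_seq_lens = seq_lens - prefix_lens
extend_start_loc = torch.zeros_like(seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
print(extend_start_loc.tolist())  # [0, 4]
```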