sglang 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/backend/runtime_endpoint.py +14 -4
- sglang/bench_latency.py +6 -3
- sglang/global_config.py +22 -16
- sglang/lang/chat_template.py +2 -2
- sglang/lang/ir.py +3 -3
- sglang/srt/layers/radix_attention.py +14 -37
- sglang/srt/layers/token_attention.py +2 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/infer_batch.py +256 -42
- sglang/srt/managers/controller/manager_multi.py +6 -2
- sglang/srt/managers/controller/manager_single.py +125 -50
- sglang/srt/managers/controller/model_runner.py +69 -284
- sglang/srt/managers/controller/radix_cache.py +4 -3
- sglang/srt/managers/controller/schedule_heuristic.py +4 -0
- sglang/srt/managers/controller/tp_worker.py +44 -44
- sglang/srt/memory_pool.py +52 -50
- sglang/srt/models/minicpm.py +1 -8
- sglang/srt/models/qwen2_moe.py +126 -107
- sglang/srt/server.py +11 -15
- sglang/srt/server_args.py +12 -4
- sglang/srt/utils.py +1 -1
- {sglang-0.1.19.dist-info → sglang-0.1.21.dist-info}/METADATA +9 -1
- {sglang-0.1.19.dist-info → sglang-0.1.21.dist-info}/RECORD +27 -26
- {sglang-0.1.19.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- {sglang-0.1.19.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.19.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/infer_batch.py

@@ -15,10 +15,16 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
+# Store some global server args
+global_server_args_dict = {}
+
 
 class ForwardMode(IntEnum):
+    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
     PREFILL = auto()
+    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
     EXTEND = auto()
+    # Decode one token.
     DECODE = auto()
 
 
@@ -66,7 +72,10 @@ class FINISH_ABORT(BaseFinishReason):
 
 
 class Req:
+    """Store all inforamtion of a request."""
+
     def __init__(self, rid, origin_input_text, origin_input_ids):
+        # Input and output info
         self.rid = rid
         self.origin_input_text = origin_input_text
         self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
@@ -74,7 +83,7 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.input_ids = None  # input_ids = origin_input_ids + output_ids
 
-        # For incremental
+        # For incremental decoding
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
@@ -89,20 +98,19 @@ class Req:
         self.image_offset = 0
         self.pad_value = None
 
+        # Prefix info
+        self.extend_input_len = 0
+        self.prefix_indices = []
+        self.last_node = None
+
         # Sampling parameters
         self.sampling_params = None
         self.stream = False
 
-        self.tokenizer = None
-
         # Check finish
+        self.tokenizer = None
         self.finished_reason = None
 
-        # Prefix info
-        self.extend_input_len = 0
-        self.prefix_indices = []
-        self.last_node = None
-
         # Logprobs
         self.return_logprob = False
         self.logprob_start_len = 0
@@ -166,9 +174,6 @@ class Req:
 
         return False, ""
 
-    def max_new_tokens(self):
-        return self.sampling_params.max_new_tokens
-
     def check_finished(self):
         if self.finished():
             return
@@ -252,35 +257,36 @@ class Req:
 
 @dataclass
 class Batch:
+    """Store all inforamtion of a batch."""
+
+    # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool
     tree_cache: RadixCache
 
-    #
+    # Batched arguments to model runner
    input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
 
-    #
+    # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    #
+    # For multimodal
     pixel_values: List[torch.Tensor] = None
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None
 
-    #
+    # Other arguments for control
     output_ids: torch.Tensor = None
     extend_num_tokens: int = None
 
-    #
+    # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None
@@ -303,8 +309,8 @@ class Batch:
     def is_empty(self):
         return len(self.reqs) == 0
 
-    # whether batch has at least 1 streaming request
     def has_stream(self) -> bool:
+        # Return whether batch has at least 1 streaming request
        return any(r.stream for r in self.reqs)
 
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
@@ -338,12 +344,12 @@ class Batch:
 
         position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
 
-        #
+        # Allocate memory
         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.
+            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
             out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
 
         if out_cache_loc is None:
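Note: this hunk and the two that follow switch the eviction callback to `token_to_kv_pool.free`. The surrounding logic is an allocate, evict, retry pattern: try to allocate KV slots, and if the pool is exhausted, evict entries from the radix cache back into the pool and try once more. A minimal sketch of that pattern with a hypothetical pool and eviction callback (illustrative stand-ins, not the sglang classes; `cache_evict` plays the role of `tree_cache.evict`):

```python
# Sketch of the allocate -> evict -> retry pattern; SimplePool and cache_evict
# are illustrative stand-ins, not the sglang API.
class SimplePool:
    def __init__(self, size: int):
        self.free_slots = list(range(size))

    def alloc(self, n: int):
        if len(self.free_slots) < n:
            return None  # mirrors token_to_kv_pool.alloc returning None when full
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out

    def free(self, slots):
        self.free_slots.extend(slots)


def alloc_with_eviction(pool: SimplePool, cache_evict, n: int):
    out = pool.alloc(n)
    if out is None:
        # Evict cached entries; the eviction routine returns slots via pool.free.
        cache_evict(n, pool.free)
        out = pool.alloc(n)
    if out is None:
        raise RuntimeError("KV cache out of memory even after eviction")
    return out
```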
@@ -413,7 +419,7 @@ class Batch:
         if self.token_to_kv_pool.available_size() >= bs:
             return True
 
-        self.tree_cache.evict(bs, self.token_to_kv_pool.
+        self.tree_cache.evict(bs, self.token_to_kv_pool.free)
 
         if self.token_to_kv_pool.available_size() >= bs:
             return True
@@ -444,7 +450,7 @@ class Batch:
             token_indices = self.req_to_token_pool.req_to_token[
                 req_pool_indices_cpu[idx]
             ][last_uncached_pos : seq_lens_cpu[idx]]
-            self.token_to_kv_pool.
+            self.token_to_kv_pool.free(token_indices)
 
             # release the last node
             self.tree_cache.dec_lock_ref(req.last_node)
@@ -555,21 +561,12 @@ class Batch:
 
         # Alloc mem
         bs = len(self.reqs)
-
-        if alloc_res is None:
-            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
+        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
 
-
-
-
-
-
-            self.out_cache_cont_start = None
-            self.out_cache_cont_end = None
-        else:
-            self.out_cache_loc = alloc_res[0]
-            self.out_cache_cont_start = alloc_res[1]
-            self.out_cache_cont_end = alloc_res[2]
+        if self.out_cache_loc is None:
+            print("Decode out of memory. This should never happen.")
+            self.tree_cache.pretty_print()
+            exit()
 
         self.req_to_token_pool.req_to_token[
             self.req_pool_indices, self.seq_lens - 1
@@ -583,7 +580,7 @@ class Batch:
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
-        self.out_cache_loc =
+        self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
@@ -596,8 +593,7 @@ class Batch:
             "logit_bias",
         ]:
             self_val = getattr(self, item, None)
-            # logit_bias can be None
-            if self_val is not None:
+            if self_val is not None:  # logit_bias can be None
                 setattr(self, item, self_val[new_indices])
 
     def merge(self, other: "Batch"):
@@ -611,7 +607,7 @@ class Batch:
         self.position_ids_offsets = torch.concat(
             [self.position_ids_offsets, other.position_ids_offsets]
         )
-        self.out_cache_loc =
+        self.out_cache_loc = None
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
@@ -664,7 +660,13 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
         probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
-
+        try:
+            sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        except RuntimeError as e:
+            warnings.warn(f"Ignore errors in sampling: {e}")
+            sampled_index = torch.ones(
+                probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
+            )
         batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
             -1
         )
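Note: this hunk guards `torch.multinomial` so that a rare numerical failure (for example a row that becomes all zero or NaN after filtering) downgrades to a warning and a fixed fallback index instead of crashing the whole batch. A simplified, standalone sketch of top-p/top-k filtering with the same guarded sampling (not the exact sglang `_top_p_top_k` implementation):

```python
import warnings

import torch


def top_p_top_k_sample(logits: torch.Tensor, top_p: float, top_k: int) -> torch.Tensor:
    """Sample one token id per row after top-p/top-k filtering, with a guarded fallback."""
    probs = torch.softmax(logits, dim=-1)
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)

    # Zero out tokens outside the nucleus (top-p) and beyond the top-k cutoff.
    cumsum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(cumsum - probs_sort) > top_p] = 0.0
    probs_sort[:, top_k:] = 0.0

    try:
        sampled = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError as e:
        # e.g. a row of all zeros / NaNs; fall back to the second-best token (index 1).
        warnings.warn(f"Ignore errors in sampling: {e}")
        sampled = torch.ones(
            probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
        )
    return torch.gather(probs_idx, dim=1, index=sampled).view(-1)


print(top_p_top_k_sample(torch.randn(2, 8), top_p=0.9, top_k=4))
```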
@@ -692,3 +694,215 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor
         ] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     return probs_sort, probs_idx
+
+
+@dataclass
+class InputMetadata:
+    """Store all inforamtion of a forward pass."""
+
+    forward_mode: ForwardMode
+    batch_size: int
+    total_num_tokens: int
+    req_pool_indices: torch.Tensor
+    seq_lens: torch.Tensor
+    positions: torch.Tensor
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool: TokenToKVPool
+
+    # For extend
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+    extend_no_prefix: bool
+
+    # Output location of the KV cache
+    out_cache_loc: torch.Tensor = None
+
+    # Output options
+    return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
+
+    # Trition attention backend
+    triton_max_seq_len: int = 0
+    triton_max_extend_len: int = 0
+    triton_start_loc: torch.Tensor = None
+    triton_prefix_lens: torch.Tensor = None
+
+    # FlashInfer attention backend
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+
+    @classmethod
+    def create(
+        cls,
+        model_runner,
+        forward_mode,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        position_ids_offsets,
+        out_cache_loc,
+        top_logprobs_nums=None,
+        return_logprob=False,
+        skip_flashinfer_init=False,
+    ):
+        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            init_flashinfer_args(
+                forward_mode,
+                model_runner,
+                req_pool_indices,
+                seq_lens,
+                prefix_lens,
+                model_runner.flashinfer_decode_wrapper,
+            )
+
+        batch_size = len(req_pool_indices)
+
+        if forward_mode == ForwardMode.DECODE:
+            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+            extend_seq_lens = extend_start_loc = extend_no_prefix = None
+            if not model_runner.server_args.disable_flashinfer:
+                # This variable is not needed in this case,
+                # we do not compute it to make it compatbile with cuda graph.
+                total_num_tokens = None
+            else:
+                total_num_tokens = int(torch.sum(seq_lens))
+        else:
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+            positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
+                        )
+                        for i in range(batch_size)
+                    ],
+                    axis=0,
+                ),
+                device="cuda",
+            )
+            extend_seq_lens = seq_lens - prefix_lens
+            extend_start_loc = torch.zeros_like(seq_lens)
+            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+            extend_no_prefix = torch.all(prefix_lens == 0)
+            total_num_tokens = int(torch.sum(seq_lens))
+
+        ret = cls(
+            forward_mode=forward_mode,
+            batch_size=batch_size,
+            total_num_tokens=total_num_tokens,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            positions=positions,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            out_cache_loc=out_cache_loc,
+            extend_seq_lens=extend_seq_lens,
+            extend_start_loc=extend_start_loc,
+            extend_no_prefix=extend_no_prefix,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+        )
+
+        if model_runner.server_args.disable_flashinfer:
+            (
+                ret.triton_max_seq_len,
+                ret.triton_max_extend_len,
+                ret.triton_start_loc,
+                ret.triton_prefix_lens,
+            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
+
+        return ret
+
+
+def init_flashinfer_args(
+    forward_mode,
+    model_runner,
+    req_pool_indices,
+    seq_lens,
+    prefix_lens,
+    flashinfer_decode_wrapper,
+):
+    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
+    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
+    head_dim = model_runner.model_config.head_dim
+    batch_size = len(req_pool_indices)
+
+    if forward_mode == ForwardMode.DECODE:
+        paged_kernel_lens = seq_lens
+    else:
+        paged_kernel_lens = prefix_lens
+
+    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+    kv_indices = torch.cat(
+        [
+            model_runner.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+            ]
+            for i in range(batch_size)
+        ],
+        dim=0,
+    ).contiguous()
+    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+    if forward_mode == ForwardMode.DECODE:
+        flashinfer_decode_wrapper.end_forward()
+        flashinfer_decode_wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+    else:
+        # extend part
+        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+            qo_indptr,
+            qo_indptr,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+        )
+
+        # cached part
+        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+
+
+def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    batch_size = len(seq_lens)
+    max_seq_len = int(torch.max(seq_lens))
+    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+
+    if forward_mode == ForwardMode.DECODE:
+        max_extend_len = None
+    else:
+        extend_seq_lens = seq_lens - prefix_lens
+        max_extend_len = int(torch.max(extend_seq_lens))
+
+    return max_seq_len, max_extend_len, start_loc, prefix_lens
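Note: to make the ragged-batch bookkeeping in `InputMetadata.create` and `init_triton_args` above concrete, here is a small CPU-only illustration (toy lengths, zero position offsets, no model runner) of how `positions`, `extend_start_loc`, and `start_loc` fall out of `seq_lens` and `prefix_lens`:

```python
import numpy as np
import torch

# Toy EXTEND batch: request 0 has 5 tokens with 2 already cached (prefix),
# request 1 has 3 tokens with no cached prefix.
seq_lens = torch.tensor([5, 3])
prefix_lens = torch.tensor([2, 0])

# Position ids of the newly extended tokens, concatenated across requests.
positions = torch.tensor(
    np.concatenate(
        [np.arange(int(p), int(s)) for p, s in zip(prefix_lens, seq_lens)], axis=0
    )
)
print(positions)  # tensor([2, 3, 4, 0, 1, 2])

# Start offset of each request inside the flattened extend tokens
# (exclusive cumulative sum of the per-request extend lengths).
extend_seq_lens = seq_lens - prefix_lens
extend_start_loc = torch.zeros_like(seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
print(extend_start_loc)  # tensor([0, 3])

# init_triton_args uses the same exclusive-cumsum trick over full sequence lengths.
start_loc = torch.zeros_like(seq_lens)
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
print(start_loc)  # tensor([0, 5])
```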
sglang/srt/managers/controller/manager_multi.py

@@ -42,6 +42,8 @@ class LoadBalanceMethod(Enum):
 
 
 class Controller:
+    """A controller that manages multiple data parallel workers."""
+
     def __init__(
         self,
         load_balance_method: str,
@@ -183,9 +185,11 @@ def start_controller_process(
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
-
     pipe_writer.send("init ok")
-
+
+    loop = asyncio.new_event_loop()
+    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
+
     asyncio.set_event_loop(loop)
     loop.create_task(controller.loop_for_recv_requests())
     loop.run_until_complete(controller.loop_for_forward())
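Note: this hunk moves the event-loop setup into the data parallel controller and gives the loop a wide default `ThreadPoolExecutor`, so many blocking calls issued through `run_in_executor(None, ...)` can overlap. A minimal, self-contained sketch of that pattern (the worker function and pool size here are illustrative):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


def blocking_work(n: int) -> int:
    return sum(range(n))  # stand-in for a blocking call


async def main():
    loop = asyncio.get_running_loop()
    # These all run on the loop's default executor configured below.
    results = await asyncio.gather(
        *(loop.run_in_executor(None, blocking_work, 100_000) for _ in range(8))
    )
    print(len(results))


loop = asyncio.new_event_loop()
loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
asyncio.set_event_loop(loop)
loop.run_until_complete(main())
```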
sglang/srt/managers/controller/manager_single.py

@@ -1,28 +1,104 @@
 """A controller that manages a group of tensor parallel workers."""
 
-import
+import multiprocessing
 import logging
-
+import os
+import pickle
 
-import
+import torch
+import torch.distributed as dist
 import zmq
 import zmq.asyncio
 
-from sglang.
-from sglang.srt.
-from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.managers.controller.tp_worker import ModelTpServer
+from sglang.srt.server_args import PortArgs, ServerArgs, ModelPortArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
-asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 logger = logging.getLogger("srt.controller")
 
 
+def run_tp_server(
+    gpu_id: int,
+    tp_rank: int,
+    server_args: ServerArgs,
+    model_port_args: ModelPortArgs,
+    model_overide_args: dict,
+):
+    """Run a tp server."""
+    try:
+        model_server = ModelTpServer(
+            gpu_id,
+            tp_rank,
+            server_args,
+            model_port_args,
+            model_overide_args,
+        )
+        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+        while True:
+            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+            model_server.exposed_step(recv_reqs)
+    except Exception:
+        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+        raise
+
+
+def launch_tp_servers(gpu_ids, tp_rank_range, server_args,
+                      model_port_args, model_overide_args):
+    """Launch multiple tp servers."""
+    procs = []
+    for i in tp_rank_range:
+        proc = multiprocessing.Process(target=run_tp_server, args=(
+            gpu_ids[i], i, server_args, model_port_args, model_overide_args
+        ))
+        proc.start()
+        procs.append(proc)
+
+    return procs
+
+
+def broadcast_recv_input(data, rank, dist_group):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+        return data
+
+
 class ControllerSingle:
-
+    """A controller that manages a group of tensor parallel workers."""
+
+    def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
+        # Parse args
+        self.server_args = server_args
+
         # Init communication
-        context = zmq.
+        context = zmq.Context(2)
         self.recv_from_tokenizer = context.socket(zmq.PULL)
         self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
 
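Note: the new `broadcast_recv_input` helper above sends a pickled Python object from tensor parallel rank 0 to the other ranks over a CPU process group, using a size broadcast followed by a payload broadcast. A hedged sketch of how the two sides pair up (a hypothetical driver, mirroring `loop_for_forward` on rank 0 and `run_tp_server` on the other ranks):

```python
# Hypothetical per-rank step; `group` is a CPU (gloo) process group and
# `tp_server` is a ModelTpServer-like object. Not part of the sglang API.
def step_once(rank: int, group, drain_requests, tp_server):
    if rank == 0:
        recv_reqs = drain_requests()                 # e.g. drained from the zmq socket
        broadcast_recv_input(recv_reqs, 0, group)    # push to ranks 1..N-1
    else:
        # Blocks until rank 0 broadcasts; returns the unpickled request list.
        recv_reqs = broadcast_recv_input(None, rank, group)
    return tp_server.exposed_step(recv_reqs)
```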
@@ -31,44 +107,52 @@ class ControllerSingle:
             f"tcp://127.0.0.1:{port_args.detokenizer_port}"
         )
 
-        # Init
-
-
-
-        #
-
+        # Init model server
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+
+        # Launch other tp ranks
+        if tp_size_local > 1:
+            tp_rank_range = range(1, tp_size_local)
+            self.tp_procs = launch_tp_servers(
+                gpu_ids, tp_rank_range, server_args,
+                port_args.model_port_args[0], model_overide_args)
+
+        # Launch tp rank 0
+        self.tp_server = ModelTpServer(
+            gpu_ids[0],
+            0,
+            server_args,
+            port_args.model_port_args[0],
+            model_overide_args,
+        )
+        self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
 
-
+    def loop_for_forward(self):
         while True:
-
-
-
+            recv_reqs = self.recv_requests()
+
+            if self.server_args.tp_size > 1:
+                broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
+
+            out_pyobjs = self.tp_server.exposed_step(recv_reqs)
 
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)
 
-
-
-            if len(out_pyobjs) != 0:
-                has_finished = any(
-                    [obj.finished_reason is not None for obj in out_pyobjs]
-                )
-                if has_finished:
-                    if self.request_dependency_delay > 0:
-                        slept = True
-                        await asyncio.sleep(self.request_dependency_delay)
-
-            if not slept:
-                await asyncio.sleep(global_config.wait_for_new_request_delay)
-
-    async def loop_for_recv_requests(self):
+    def recv_requests(self):
+        recv_reqs = []
         while True:
-
-
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                recv_reqs.append(recv_req)
+            except zmq.ZMQError:
+                break
+        return recv_reqs
 
 
 def start_controller_process(
-    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args: dict
 ):
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
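Note: the new `recv_requests` method drains the tokenizer socket without blocking; it keeps calling `recv_pyobj(zmq.NOBLOCK)` until `zmq.ZMQError` signals that no more messages are queued. A standalone toy demonstration of that drain loop (the port number and sleep are arbitrary):

```python
import time

import zmq

ctx = zmq.Context()
push = ctx.socket(zmq.PUSH)
push.bind("tcp://127.0.0.1:5599")
pull = ctx.socket(zmq.PULL)
pull.connect("tcp://127.0.0.1:5599")

for i in range(3):
    push.send_pyobj({"rid": i})
time.sleep(0.1)  # give the messages time to arrive

# Drain everything currently queued without blocking.
reqs = []
while True:
    try:
        reqs.append(pull.recv_pyobj(zmq.NOBLOCK))
    except zmq.ZMQError:
        break
print(len(reqs))  # expected 3 if the sleep was long enough
```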
@@ -76,27 +160,18 @@ def start_controller_process(
     )
 
     try:
-
-        model_client = ModelTpClient(
-            [i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
-            server_args,
-            port_args.model_port_args[0],
-            model_overide_args,
-        )
-        controller = ControllerSingle(model_client, port_args)
+        controller = ControllerSingle(server_args, port_args, model_overide_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
 
     pipe_writer.send("init ok")
 
-    loop = asyncio.new_event_loop()
-    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
-    asyncio.set_event_loop(loop)
-    loop.create_task(controller.loop_for_recv_requests())
     try:
-
+        controller.loop_for_forward()
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
+        for t in controller.tp_procs:
+            os.kill(t.pid, 9)
         kill_parent_process()