sglang 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,10 +15,16 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
+# Store some global server args
+global_server_args_dict = {}
+
 
 class ForwardMode(IntEnum):
+    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
     PREFILL = auto()
+    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
     EXTEND = auto()
+    # Decode one token.
     DECODE = auto()
 
 
@@ -66,7 +72,10 @@ class FINISH_ABORT(BaseFinishReason):
 
 
 class Req:
+    """Store all inforamtion of a request."""
+
     def __init__(self, rid, origin_input_text, origin_input_ids):
+        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
@@ -74,7 +83,7 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.input_ids = None  # input_ids = origin_input_ids + output_ids
 
-        # For incremental decode
+        # For incremental decoding
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
@@ -89,20 +98,19 @@ class Req:
         self.image_offset = 0
         self.pad_value = None
 
+        # Prefix info
+        self.extend_input_len = 0
+        self.prefix_indices = []
+        self.last_node = None
+
         # Sampling parameters
         self.sampling_params = None
         self.stream = False
 
-        self.tokenizer = None
-
         # Check finish
+        self.tokenizer = None
         self.finished_reason = None
 
-        # Prefix info
-        self.extend_input_len = 0
-        self.prefix_indices = []
-        self.last_node = None
-
         # Logprobs
         self.return_logprob = False
         self.logprob_start_len = 0
@@ -166,9 +174,6 @@ class Req:
 
         return False, ""
 
-    def max_new_tokens(self):
-        return self.sampling_params.max_new_tokens
-
     def check_finished(self):
         if self.finished():
             return
@@ -252,35 +257,36 @@ class Req:
 
 @dataclass
 class Batch:
+    """Store all inforamtion of a batch."""
+
+    # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool
     tree_cache: RadixCache
 
-    # batched arguments to model runner
+    # Batched arguments to model runner
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
 
-    # for processing logprobs
+    # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # for multimodal
+    # For multimodal
     pixel_values: List[torch.Tensor] = None
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None
 
-    # other arguments for control
+    # Other arguments for control
     output_ids: torch.Tensor = None
     extend_num_tokens: int = None
 
-    # batched sampling params
+    # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None
@@ -303,8 +309,8 @@ class Batch:
     def is_empty(self):
         return len(self.reqs) == 0
 
-    # whether batch has at least 1 streaming request
     def has_stream(self) -> bool:
+        # Return whether batch has at least 1 streaming request
         return any(r.stream for r in self.reqs)
 
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
@@ -338,12 +344,12 @@ class Batch:
 
         position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
 
-        # Alloc mem
+        # Allocate memory
         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
+            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
             out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
 
             if out_cache_loc is None:
@@ -413,7 +419,7 @@ class Batch:
         if self.token_to_kv_pool.available_size() >= bs:
             return True
 
-        self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)
+        self.tree_cache.evict(bs, self.token_to_kv_pool.free)
 
         if self.token_to_kv_pool.available_size() >= bs:
             return True
@@ -444,7 +450,7 @@ class Batch:
             token_indices = self.req_to_token_pool.req_to_token[
                 req_pool_indices_cpu[idx]
             ][last_uncached_pos : seq_lens_cpu[idx]]
-            self.token_to_kv_pool.dec_refs(token_indices)
+            self.token_to_kv_pool.free(token_indices)
 
             # release the last node
             self.tree_cache.dec_lock_ref(req.last_node)
@@ -555,21 +561,12 @@ class Batch:
 
         # Alloc mem
         bs = len(self.reqs)
-        alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
-        if alloc_res is None:
-            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
+        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
 
-            if self.out_cache_loc is None:
-                print("Decode out of memory. This should never happen.")
-                self.tree_cache.pretty_print()
-                exit()
-
-            self.out_cache_cont_start = None
-            self.out_cache_cont_end = None
-        else:
-            self.out_cache_loc = alloc_res[0]
-            self.out_cache_cont_start = alloc_res[1]
-            self.out_cache_cont_end = alloc_res[2]
+        if self.out_cache_loc is None:
+            print("Decode out of memory. This should never happen.")
+            self.tree_cache.pretty_print()
+            exit()
 
         self.req_to_token_pool.req_to_token[
             self.req_pool_indices, self.seq_lens - 1
@@ -583,7 +580,7 @@ class Batch:
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
@@ -596,8 +593,7 @@ class Batch:
             "logit_bias",
         ]:
             self_val = getattr(self, item, None)
-            # logit_bias can be None
-            if self_val is not None:
+            if self_val is not None:  # logit_bias can be None
                 setattr(self, item, self_val[new_indices])
 
     def merge(self, other: "Batch"):
@@ -611,7 +607,7 @@ class Batch:
         self.position_ids_offsets = torch.concat(
             [self.position_ids_offsets, other.position_ids_offsets]
         )
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.out_cache_loc = None
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
@@ -664,7 +660,13 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
         probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        try:
+            sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        except RuntimeError as e:
+            warnings.warn(f"Ignore errors in sampling: {e}")
+            sampled_index = torch.ones(
+                probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
+            )
         batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
             -1
         )
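
The hunk above is the behavioral change in `Batch.sample`: `torch.multinomial` is now wrapped in a `try`/`except RuntimeError`, so a rare failure on a filtered probability row no longer aborts the whole decoding step; the fallback substitutes a constant index into the sorted distribution. A minimal, self-contained sketch of the same guard follows; the `_safe_sample` name and the toy tensor are illustrative, not part of the package.

```python
import warnings

import torch


def _safe_sample(probs_sort: torch.Tensor) -> torch.Tensor:
    """Sample one index per row, falling back to a fixed index on failure."""
    try:
        # Normal path: one sample per row of the filtered distribution.
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError as e:
        # Same spirit as the new code above: warn and pick a constant index
        # instead of crashing the whole batch.
        warnings.warn(f"Ignore errors in sampling: {e}")
        sampled_index = torch.ones(
            probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs_sort.device
        )
    return sampled_index


if __name__ == "__main__":
    probs_sort = torch.tensor([[0.7, 0.2, 0.1], [0.5, 0.5, 0.0]])
    print(_safe_sample(probs_sort))  # e.g. tensor([[0], [1]])
```
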
@@ -692,3 +694,215 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor
     ] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     return probs_sort, probs_idx
+
+
+@dataclass
+class InputMetadata:
+    """Store all inforamtion of a forward pass."""
+
+    forward_mode: ForwardMode
+    batch_size: int
+    total_num_tokens: int
+    req_pool_indices: torch.Tensor
+    seq_lens: torch.Tensor
+    positions: torch.Tensor
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool: TokenToKVPool
+
+    # For extend
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+    extend_no_prefix: bool
+
+    # Output location of the KV cache
+    out_cache_loc: torch.Tensor = None
+
+    # Output options
+    return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
+
+    # Trition attention backend
+    triton_max_seq_len: int = 0
+    triton_max_extend_len: int = 0
+    triton_start_loc: torch.Tensor = None
+    triton_prefix_lens: torch.Tensor = None
+
+    # FlashInfer attention backend
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+
+    @classmethod
+    def create(
+        cls,
+        model_runner,
+        forward_mode,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        position_ids_offsets,
+        out_cache_loc,
+        top_logprobs_nums=None,
+        return_logprob=False,
+        skip_flashinfer_init=False,
+    ):
+        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            init_flashinfer_args(
+                forward_mode,
+                model_runner,
+                req_pool_indices,
+                seq_lens,
+                prefix_lens,
+                model_runner.flashinfer_decode_wrapper,
+            )
+
+        batch_size = len(req_pool_indices)
+
+        if forward_mode == ForwardMode.DECODE:
+            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+            extend_seq_lens = extend_start_loc = extend_no_prefix = None
+            if not model_runner.server_args.disable_flashinfer:
+                # This variable is not needed in this case,
+                # we do not compute it to make it compatbile with cuda graph.
+                total_num_tokens = None
+            else:
+                total_num_tokens = int(torch.sum(seq_lens))
+        else:
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+            positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
+                        )
+                        for i in range(batch_size)
+                    ],
+                    axis=0,
+                ),
+                device="cuda",
+            )
+            extend_seq_lens = seq_lens - prefix_lens
+            extend_start_loc = torch.zeros_like(seq_lens)
+            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+            extend_no_prefix = torch.all(prefix_lens == 0)
+            total_num_tokens = int(torch.sum(seq_lens))
+
+        ret = cls(
+            forward_mode=forward_mode,
+            batch_size=batch_size,
+            total_num_tokens=total_num_tokens,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            positions=positions,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            out_cache_loc=out_cache_loc,
+            extend_seq_lens=extend_seq_lens,
+            extend_start_loc=extend_start_loc,
+            extend_no_prefix=extend_no_prefix,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+        )
+
+        if model_runner.server_args.disable_flashinfer:
+            (
+                ret.triton_max_seq_len,
+                ret.triton_max_extend_len,
+                ret.triton_start_loc,
+                ret.triton_prefix_lens,
+            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
+
+        return ret
+
+
+def init_flashinfer_args(
+    forward_mode,
+    model_runner,
+    req_pool_indices,
+    seq_lens,
+    prefix_lens,
+    flashinfer_decode_wrapper,
+):
+    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
+    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
+    head_dim = model_runner.model_config.head_dim
+    batch_size = len(req_pool_indices)
+
+    if forward_mode == ForwardMode.DECODE:
+        paged_kernel_lens = seq_lens
+    else:
+        paged_kernel_lens = prefix_lens
+
+    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+    kv_indices = torch.cat(
+        [
+            model_runner.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+            ]
+            for i in range(batch_size)
+        ],
+        dim=0,
+    ).contiguous()
+    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+    if forward_mode == ForwardMode.DECODE:
+        flashinfer_decode_wrapper.end_forward()
+        flashinfer_decode_wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+    else:
+        # extend part
+        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+            qo_indptr,
+            qo_indptr,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+        )
+
+        # cached part
+        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+
+
+def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    batch_size = len(seq_lens)
+    max_seq_len = int(torch.max(seq_lens))
+    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+
+    if forward_mode == ForwardMode.DECODE:
+        max_extend_len = None
+    else:
+        extend_seq_lens = seq_lens - prefix_lens
+        max_extend_len = int(torch.max(extend_seq_lens))
+
+    return max_seq_len, max_extend_len, start_loc, prefix_lens
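
Both backends above set up ragged-batch indexing with exclusive prefix sums: `extend_start_loc` (and the Triton `start_loc`) holds each request's offset in the packed token buffer, while the FlashInfer `qo_indptr`/`kv_indptr` arrays are the same cumulative sums with a leading zero and a trailing total, one entry per request plus one. A small sketch of how those arrays fall out of `seq_lens` and `prefix_lens`; the numbers are invented for illustration and the tensors live on the CPU rather than on `"cuda"` as in the diff.

```python
import torch

# Example batch: three requests with cached prefixes of different lengths.
seq_lens = torch.tensor([5, 3, 7])     # total tokens per request
prefix_lens = torch.tensor([2, 0, 4])  # tokens already present in the KV cache

# Tokens that still have to be computed in this extend pass.
extend_seq_lens = seq_lens - prefix_lens             # tensor([3, 3, 3])

# Offset of each request inside the packed extend input
# (exclusive cumulative sum, as in InputMetadata.create).
extend_start_loc = torch.zeros_like(seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
print(extend_start_loc)                              # tensor([0, 3, 6])

# Index pointers in the FlashInfer style: batch_size + 1 entries,
# starting at 0 and ending at the total token count.
qo_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_seq_lens, dim=0)
print(qo_indptr)                                     # [0, 3, 6, 9]

kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)     # cached (paged) part for extend
print(kv_indptr)                                     # [0, 2, 2, 6]
```
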
@@ -42,6 +42,8 @@ class LoadBalanceMethod(Enum):
 
 
 class Controller:
+    """A controller that manages multiple data parallel workers."""
+
     def __init__(
         self,
         load_balance_method: str,
@@ -183,9 +185,11 @@ def start_controller_process(
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
-
     pipe_writer.send("init ok")
-    loop = asyncio.get_event_loop()
+
+    loop = asyncio.new_event_loop()
+    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
+
     asyncio.set_event_loop(loop)
     loop.create_task(controller.loop_for_recv_requests())
     loop.run_until_complete(controller.loop_for_forward())
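
For the data-parallel controller, `start_controller_process` now creates a fresh event loop and installs a much larger default executor, presumably so that many blocking calls handed off via `run_in_executor` can run concurrently instead of queueing behind the small default thread pool. A minimal sketch of that pattern under the same assumption; `blocking_call` and the worker count are illustrative only.

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor


def blocking_call(i: int) -> int:
    time.sleep(0.1)  # stand-in for a blocking request to a worker
    return i


async def main(loop: asyncio.AbstractEventLoop) -> None:
    # Passing None selects the loop's default executor, enlarged below.
    results = await asyncio.gather(
        *(loop.run_in_executor(None, blocking_call, i) for i in range(32))
    )
    print(results[:4])  # [0, 1, 2, 3]


loop = asyncio.new_event_loop()
loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
asyncio.set_event_loop(loop)
loop.run_until_complete(main(loop))
loop.close()
```
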
@@ -1,28 +1,104 @@
 """A controller that manages a group of tensor parallel workers."""
 
-import asyncio
+import multiprocessing
 import logging
-from concurrent.futures import ThreadPoolExecutor
+import os
+import pickle
 
-import uvloop
+import torch
+import torch.distributed as dist
 import zmq
 import zmq.asyncio
 
-from sglang.global_config import global_config
-from sglang.srt.managers.controller.tp_worker import ModelTpClient
-from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.managers.controller.tp_worker import ModelTpServer
+from sglang.srt.server_args import PortArgs, ServerArgs, ModelPortArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
-asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 logger = logging.getLogger("srt.controller")
 
 
+def run_tp_server(
+    gpu_id: int,
+    tp_rank: int,
+    server_args: ServerArgs,
+    model_port_args: ModelPortArgs,
+    model_overide_args: dict,
+):
+    """Run a tp server."""
+    try:
+        model_server = ModelTpServer(
+            gpu_id,
+            tp_rank,
+            server_args,
+            model_port_args,
+            model_overide_args,
+        )
+        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+        while True:
+            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+            model_server.exposed_step(recv_reqs)
+    except Exception:
+        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+        raise
+
+
+def launch_tp_servers(gpu_ids, tp_rank_range, server_args,
+                      model_port_args, model_overide_args):
+    """Launch multiple tp servers."""
+    procs = []
+    for i in tp_rank_range:
+        proc = multiprocessing.Process(target=run_tp_server, args=(
+            gpu_ids[i], i, server_args, model_port_args, model_overide_args
+        ))
+        proc.start()
+        procs.append(proc)
+
+    return procs
+
+
+def broadcast_recv_input(data, rank, dist_group):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+        return data
+
+
 class ControllerSingle:
-    def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
+    """A controller that manages a group of tensor parallel workers."""
+
+    def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
+        # Parse args
+        self.server_args = server_args
+
         # Init communication
-        context = zmq.asyncio.Context(2)
+        context = zmq.Context(2)
         self.recv_from_tokenizer = context.socket(zmq.PULL)
         self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
 
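The new `broadcast_recv_input` helper moves arbitrary Python objects from tensor-parallel rank 0 to the other ranks over the CPU process group: rank 0 pickles the request list, broadcasts the byte count, then broadcasts the bytes themselves as a `uint8` tensor; the receivers reverse the steps. A stand-alone sketch of that size-then-payload handshake on the gloo backend; the two-process setup, port number, and `payload` contents are illustrative, not taken from the package.

```python
import os
import pickle

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if rank == 0:
        payload = [{"rid": "req-0", "input_ids": [1, 2, 3]}]
        blob = pickle.dumps(payload)
        size = torch.tensor([len(blob)], dtype=torch.long)
        data = torch.tensor(list(blob), dtype=torch.uint8)
        dist.broadcast(size, src=0)  # 1) announce the payload size
        dist.broadcast(data, src=0)  # 2) send the pickled bytes
    else:
        size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size, src=0)
        data = torch.empty(size.item(), dtype=torch.uint8)
        dist.broadcast(data, src=0)
        payload = pickle.loads(bytes(data.tolist()))
        print(f"rank {rank} received: {payload}")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```
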
@@ -31,44 +107,52 @@ class ControllerSingle:
             f"tcp://127.0.0.1:{port_args.detokenizer_port}"
         )
 
-        # Init status
-        self.model_client = model_client
-        self.recv_reqs = []
-
-        # Init some configs
-        self.request_dependency_delay = global_config.request_dependency_delay
+        # Init model server
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+
+        # Launch other tp ranks
+        if tp_size_local > 1:
+            tp_rank_range = range(1, tp_size_local)
+            self.tp_procs = launch_tp_servers(
+                gpu_ids, tp_rank_range, server_args,
+                port_args.model_port_args[0], model_overide_args)
+
+        # Launch tp rank 0
+        self.tp_server = ModelTpServer(
+            gpu_ids[0],
+            0,
+            server_args,
+            port_args.model_port_args[0],
+            model_overide_args,
+        )
+        self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
 
-    async def loop_for_forward(self):
+    def loop_for_forward(self):
         while True:
-            next_step_input = list(self.recv_reqs)
-            self.recv_reqs = []
-            out_pyobjs = await self.model_client.step(next_step_input)
+            recv_reqs = self.recv_requests()
+
+            if self.server_args.tp_size > 1:
+                broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
+
+            out_pyobjs = self.tp_server.exposed_step(recv_reqs)
 
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)
 
-            # async sleep for receiving the subsequent request and avoiding cache miss
-            slept = False
-            if len(out_pyobjs) != 0:
-                has_finished = any(
-                    [obj.finished_reason is not None for obj in out_pyobjs]
-                )
-                if has_finished:
-                    if self.request_dependency_delay > 0:
-                        slept = True
-                        await asyncio.sleep(self.request_dependency_delay)
-
-            if not slept:
-                await asyncio.sleep(global_config.wait_for_new_request_delay)
-
-    async def loop_for_recv_requests(self):
+    def recv_requests(self):
+        recv_reqs = []
         while True:
-            recv_req = await self.recv_from_tokenizer.recv_pyobj()
-            self.recv_reqs.append(recv_req)
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                recv_reqs.append(recv_req)
+            except zmq.ZMQError:
+                break
+        return recv_reqs
 
 
 def start_controller_process(
-    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args: dict
 ):
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
@@ -76,27 +160,18 @@ def start_controller_process(
     )
 
     try:
-        tp_size_local = server_args.tp_size // server_args.nnodes
-        model_client = ModelTpClient(
-            [i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
-            server_args,
-            port_args.model_port_args[0],
-            model_overide_args,
-        )
-        controller = ControllerSingle(model_client, port_args)
+        controller = ControllerSingle(server_args, port_args, model_overide_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
 
     pipe_writer.send("init ok")
 
-    loop = asyncio.new_event_loop()
-    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
-    asyncio.set_event_loop(loop)
-    loop.create_task(controller.loop_for_recv_requests())
    try:
-        loop.run_until_complete(controller.loop_for_forward())
+        controller.loop_for_forward()
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
+        for t in controller.tp_procs:
+            os.kill(t.pid, 9)
         kill_parent_process()
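
With the asyncio receive task gone, `ControllerSingle.recv_requests` drains the tokenizer socket synchronously: it calls `recv_pyobj(zmq.NOBLOCK)` in a loop and stops at the first `zmq.ZMQError` (raised as `zmq.Again` when nothing is queued), so each scheduling step picks up whatever has arrived without blocking. A small self-contained sketch of that drain pattern over an in-process PUSH/PULL pair; the endpoint name and message contents are illustrative.

```python
import zmq


def drain(socket: zmq.Socket) -> list:
    """Return every object currently queued on the socket without blocking."""
    items = []
    while True:
        try:
            items.append(socket.recv_pyobj(zmq.NOBLOCK))
        except zmq.ZMQError:
            # Raised (as zmq.Again) once no more messages are waiting.
            break
    return items


if __name__ == "__main__":
    ctx = zmq.Context()
    pull = ctx.socket(zmq.PULL)
    pull.bind("inproc://requests")
    push = ctx.socket(zmq.PUSH)
    push.connect("inproc://requests")

    for i in range(3):
        push.send_pyobj({"rid": f"req-{i}"})

    print(drain(pull))  # three dicts
    print(drain(pull))  # [] -- nothing left, and the call did not block
    ctx.destroy()
```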