sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to its public registry. It is provided for informational purposes only.
Files changed (85)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/check_env.py +3 -3
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/kimi_vl.py +38 -0
  5. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  6. sglang/srt/configs/model_config.py +15 -0
  7. sglang/srt/conversation.py +122 -1
  8. sglang/srt/disaggregation/decode.py +8 -2
  9. sglang/srt/disaggregation/fake/__init__.py +1 -0
  10. sglang/srt/disaggregation/fake/conn.py +88 -0
  11. sglang/srt/disaggregation/prefill.py +12 -3
  12. sglang/srt/disaggregation/utils.py +16 -2
  13. sglang/srt/entrypoints/engine.py +52 -21
  14. sglang/srt/entrypoints/http_server.py +27 -2
  15. sglang/srt/function_call_parser.py +97 -0
  16. sglang/srt/hf_transformers_utils.py +2 -0
  17. sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
  18. sglang/srt/layers/attention/flashinfer_backend.py +107 -82
  19. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
  20. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  21. sglang/srt/layers/attention/utils.py +1 -1
  22. sglang/srt/layers/dp_attention.py +5 -2
  23. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
  41. sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
  42. sglang/srt/layers/quantization/__init__.py +2 -2
  43. sglang/srt/layers/quantization/deep_gemm.py +1 -1
  44. sglang/srt/layers/quantization/fp8.py +20 -22
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/utils.py +35 -0
  47. sglang/srt/lora/layers.py +35 -9
  48. sglang/srt/lora/lora_manager.py +84 -35
  49. sglang/srt/managers/data_parallel_controller.py +52 -34
  50. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  51. sglang/srt/managers/schedule_batch.py +34 -15
  52. sglang/srt/managers/scheduler.py +273 -67
  53. sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
  54. sglang/srt/managers/tp_worker.py +52 -17
  55. sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
  56. sglang/srt/mem_cache/memory_pool.py +70 -36
  57. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  58. sglang/srt/model_executor/forward_batch_info.py +31 -1
  59. sglang/srt/model_executor/model_runner.py +123 -58
  60. sglang/srt/models/deepseek_nextn.py +1 -257
  61. sglang/srt/models/deepseek_v2.py +78 -18
  62. sglang/srt/models/kimi_vl.py +308 -0
  63. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  64. sglang/srt/models/llama.py +92 -30
  65. sglang/srt/models/llama4.py +2 -1
  66. sglang/srt/models/llama_eagle.py +4 -1
  67. sglang/srt/models/llama_eagle3.py +4 -1
  68. sglang/srt/models/qwen2_moe.py +8 -3
  69. sglang/srt/models/qwen2_vl.py +0 -12
  70. sglang/srt/models/qwen3_moe.py +8 -3
  71. sglang/srt/openai_api/adapter.py +49 -8
  72. sglang/srt/openai_api/protocol.py +13 -1
  73. sglang/srt/reasoning_parser.py +25 -1
  74. sglang/srt/server_args.py +83 -24
  75. sglang/srt/speculative/eagle_worker.py +3 -2
  76. sglang/srt/utils.py +91 -9
  77. sglang/test/runners.py +4 -0
  78. sglang/test/send_one.py +84 -28
  79. sglang/test/test_utils.py +67 -0
  80. sglang/version.py +1 -1
  81. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
  82. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
  83. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
  84. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
  85. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
@@ -51,6 +51,7 @@ from sglang.srt.disaggregation.utils import (
     ReqToMetadataIdxAllocator,
     TransferBackend,
 )
+from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -114,7 +115,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -127,6 +132,7 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_zmq_socket,
     kill_itself_when_parent_died,
+    point_to_point_pyobj,
     pyspy_dump_schedulers,
     set_gpu_proc_affinity,
     set_random_seed,
@@ -145,8 +151,9 @@ RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")

 @dataclass
 class GenerationBatchResult:
-    logits_output: LogitsProcessorOutput
-    next_token_ids: List[int]
+    logits_output: Optional[LogitsProcessorOutput]
+    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
+    next_token_ids: Optional[List[int]]
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
     bid: int
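
The reshaped GenerationBatchResult above distinguishes the last pipeline stage, which has sampled token ids, from earlier stages, which only carry hidden-state proxy tensors to forward. A minimal consumer sketch (illustrative only; summarize_result is not a helper in the package):

# Sketch of how a caller might branch on which GenerationBatchResult fields
# are populated; field names follow the dataclass above.
def summarize_result(result, is_last_pp_rank: bool) -> str:
    if is_last_pp_rank:
        # Last stage: sampling happened, token ids are available.
        return f"sampled {len(result.next_token_ids)} tokens (bid={result.bid})"
    # Earlier stages: only hidden-state proxy tensors to send downstream.
    return f"proxy tensors ready for bid={result.bid}"
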
@@ -171,12 +178,16 @@ class Scheduler(
         port_args: PortArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
     ):
         # Parse args
         self.server_args = server_args
         self.tp_rank = tp_rank
+        self.pp_rank = pp_rank
         self.tp_size = server_args.tp_size
+        self.pp_size = server_args.pp_size
+        self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
@@ -192,7 +203,6 @@ class Scheduler(
         self.page_size = server_args.page_size

         # Distributed rank info
-        self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
@@ -204,7 +214,7 @@ class Scheduler(

         # Init inter-process communication
         context = zmq.Context(2)
-        if self.attn_tp_rank == 0:
+        if self.pp_rank == 0 and self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
             )
@@ -248,9 +258,6 @@ class Scheduler(
         if not self.is_generation:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for embedding models.")
-        if self.model_config.is_multimodal:
-            self.enable_overlap = False
-            logger.info("Overlap scheduler is disabled for multimodal models.")

         # Launch a tensor parallel worker
         if self.enable_overlap:
@@ -262,6 +269,7 @@ class Scheduler(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
+            pp_rank=pp_rank,
             dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )
@@ -295,8 +303,18 @@ class Scheduler(
             _,
             _,
         ) = self.tp_worker.get_worker_info()
-        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+        if global_server_args_dict["max_micro_batch_size"] is None:
+            global_server_args_dict["max_micro_batch_size"] = max(
+                self.max_running_requests // server_args.pp_size, 1
+            )
+
+        self.tp_group = self.tp_worker.get_tp_group()
+        self.tp_cpu_group = self.tp_group.cpu_group
+        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
         self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
+        self.pp_group = get_pp_group()
+        self.world_group = get_world_group()
+
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)
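
The block above also seeds a default max_micro_batch_size of max(max_running_requests // pp_size, 1), capping how many requests each pipeline microbatch may hold. A worked example with illustrative numbers:

# Worked example of the default above (numbers are illustrative).
max_running_requests = 256
pp_size = 4
max_micro_batch_size = max(max_running_requests // pp_size, 1)
assert max_micro_batch_size == 64
# With a tiny request budget the floor of 1 applies:
assert max(3 // 4, 1) == 1
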
@@ -645,6 +663,7 @@ class Scheduler(
             self.cur_batch = batch

             if batch:
+                batch.launch_done = threading.Event()
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))

@@ -656,7 +675,7 @@ class Scheduler(
                         forward_mode=ForwardMode.DUMMY_FIRST,
                         next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                     )
-                    self.process_batch_result(tmp_batch, None)
+                    self.process_batch_result(tmp_batch, None, batch.launch_done)

             if self.last_batch:
                 # Process the results of the last batch
@@ -664,7 +683,10 @@ class Scheduler(
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
-                self.process_batch_result(tmp_batch, tmp_result)
+                # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
+                self.process_batch_result(
+                    tmp_batch, tmp_result, batch.launch_done if batch else None
+                )
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
                 self.check_memory()
@@ -672,26 +694,141 @@ class Scheduler(

             self.last_batch = batch

+    @DynamicGradMode()
+    def event_loop_pp(self):
+        """A non-overlap scheduler loop for pipeline parallelism."""
+        mbs = [None] * self.pp_size
+        last_mbs = [None] * self.pp_size
+        self.running_mbs = [
+            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
+        ]
+        bids = [None] * self.pp_size
+        pp_outputs: Optional[PPProxyTensors] = None
+        while True:
+            server_is_idle = True
+            for mb_id in range(self.pp_size):
+                self.running_batch = self.running_mbs[mb_id]
+                self.last_batch = last_mbs[mb_id]
+
+                recv_reqs = self.recv_requests()
+                self.process_input_requests(recv_reqs)
+                mbs[mb_id] = self.get_next_batch_to_run()
+                self.running_mbs[mb_id] = self.running_batch
+
+                self.cur_batch = mbs[mb_id]
+                if self.cur_batch:
+                    server_is_idle = False
+                    result = self.run_batch(self.cur_batch)
+
+                # send the outputs to the next step
+                if self.pp_group.is_last_rank:
+                    if self.cur_batch:
+                        next_token_ids, bids[mb_id] = (
+                            result.next_token_ids,
+                            result.bid,
+                        )
+                        pp_outputs = PPProxyTensors(
+                            {
+                                "next_token_ids": next_token_ids,
+                            }
+                        )
+                        # send the output from the last round to let the next stage worker run post processing
+                        self.pp_group.send_tensor_dict(
+                            pp_outputs.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )

+                # receive outputs and post-process (filter finished reqs) the coming microbatch
+                next_mb_id = (mb_id + 1) % self.pp_size
+                next_pp_outputs = None
+                if mbs[next_mb_id] is not None:
+                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
+                        self.pp_group.recv_tensor_dict(
+                            all_gather_group=self.attn_tp_group
+                        )
+                    )
+                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
+                    output_result = GenerationBatchResult(
+                        logits_output=None,
+                        pp_hidden_states_proxy_tensors=None,
+                        next_token_ids=next_pp_outputs["next_token_ids"],
+                        extend_input_len_per_req=None,
+                        extend_logprob_start_len_per_req=None,
+                        bid=bids[next_mb_id],
+                    )
+                    self.process_batch_result(mbs[next_mb_id], output_result)
+                    last_mbs[next_mb_id] = mbs[next_mb_id]
+
+                # carry the outputs to the next stage
+                if not self.pp_group.is_last_rank:
+                    if self.cur_batch:
+                        bids[mb_id] = result.bid
+                    if pp_outputs:
+                        # send the outputs from the last round to let the next stage worker run post processing
+                        self.pp_group.send_tensor_dict(
+                            pp_outputs.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                if not self.pp_group.is_last_rank:
+                    # send out reqs to the next stage
+                    dp_offset = self.dp_rank * self.attn_tp_size
+                    if self.attn_tp_rank == 0:
+                        point_to_point_pyobj(
+                            recv_reqs,
+                            self.pp_rank * self.tp_size + dp_offset,
+                            self.world_group.cpu_group,
+                            self.pp_rank * self.tp_size + dp_offset,
+                            (self.pp_rank + 1) * self.tp_size + dp_offset,
+                        )
+
+                    # send out proxy tensors to the next stage
+                    if self.cur_batch:
+                        self.pp_group.send_tensor_dict(
+                            result.pp_hidden_states_proxy_tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                pp_outputs = next_pp_outputs
+
+            # When the server is idle, self-check and re-init some states
+            if server_is_idle:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
-        if self.attn_tp_rank == 0:
-            recv_reqs = []
-
-            while True:
-                try:
-                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-                except zmq.ZMQError:
-                    break
-                recv_reqs.append(recv_req)
-
-            while True:
-                try:
-                    recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
-                except zmq.ZMQError:
-                    break
-                recv_reqs.append(recv_rpc)
+        if self.pp_rank == 0:
+            if self.attn_tp_rank == 0:
+                recv_reqs = []
+
+                while True:
+                    try:
+                        recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                    except zmq.ZMQError:
+                        break
+                    recv_reqs.append(recv_req)
+
+                while True:
+                    try:
+                        recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
+                    except zmq.ZMQError:
+                        break
+                    recv_reqs.append(recv_rpc)
+            else:
+                recv_reqs = None
         else:
-            recv_reqs = None
+            if self.attn_tp_rank == 0:
+                dp_offset = self.dp_rank * self.attn_tp_size
+                recv_reqs = point_to_point_pyobj(
+                    [],
+                    self.pp_rank * self.tp_size + dp_offset,
+                    self.world_group.cpu_group,
+                    (self.pp_rank - 1) * self.tp_size + dp_offset,
+                    self.pp_rank * self.tp_size + dp_offset,
+                )
+            else:
+                recv_reqs = None

         if self.server_args.enable_dp_attention:
             if self.attn_tp_rank == 0:
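
event_loop_pp above cycles through pp_size microbatch slots on every rank: each stage forwards its current slot, relays requests and hidden-state proxies downstream, and the last stage sends sampled token ids back so the next slot in the rotation can be post-processed. A dependency-free sketch of just the rotation order implied by next_mb_id = (mb_id + 1) % self.pp_size (purely illustrative; microbatch_schedule is not part of the package):

# Which slot runs now vs. which slot's results are post-processed next.
def microbatch_schedule(pp_size: int, iterations: int):
    order = []
    for _ in range(iterations):
        for mb_id in range(pp_size):
            next_mb_id = (mb_id + 1) % pp_size
            order.append((mb_id, next_mb_id))
    return order

# With pp_size=2: [(0, 1), (1, 0), (0, 1), (1, 0)]
print(microbatch_schedule(pp_size=2, iterations=2))
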
@@ -714,20 +851,27 @@ class Scheduler(
                 control_reqs = None

             if self.attn_tp_size != 1:
-                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
                 work_reqs = broadcast_pyobj(
                     work_reqs,
-                    self.attn_tp_rank,
+                    self.attn_tp_group.rank,
                     self.attn_tp_cpu_group,
-                    src=attn_tp_rank_0,
+                    src=self.attn_tp_group.ranks[0],
                 )
             if self.tp_size != 1:
                 control_reqs = broadcast_pyobj(
-                    control_reqs, self.tp_rank, self.tp_cpu_group
+                    control_reqs,
+                    self.tp_group.rank,
+                    self.tp_cpu_group,
+                    src=self.tp_group.ranks[0],
                 )
             recv_reqs = work_reqs + control_reqs
         elif self.tp_size != 1:
-            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
+            recv_reqs = broadcast_pyobj(
+                recv_reqs,
+                self.tp_group.rank,
+                self.tp_cpu_group,
+                src=self.tp_group.ranks[0],
+            )
         return recv_reqs

     def process_input_requests(self, recv_reqs: List):
@@ -899,6 +1043,7 @@ class Scheduler(
                 add_to_grammar_queue = True

         if add_to_grammar_queue:
+            req.queue_time_start = time.time()
             self.grammar_queue.append(req)
         else:
             self._add_request_to_queue(req)
@@ -1024,12 +1169,14 @@ class Scheduler(

         self.metrics_collector.log_stats(self.stats)

-    def log_decode_stats(self):
+    def log_decode_stats(self, running_batch=None):
+        batch = running_batch or self.running_batch
+
         gap_latency = time.time() - self.last_decode_stats_tic
         self.last_decode_stats_tic = time.time()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
-        num_running_reqs = len(self.running_batch.reqs)
+        num_running_reqs = len(batch.reqs)
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -1129,19 +1276,25 @@ class Scheduler(

     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
+        chunked_req_to_exclude = set()
+        if self.chunked_req:
+            # Move the chunked request out of the batch so that we can merge
+            # only finished requests to running_batch.
+            chunked_req_to_exclude.add(self.chunked_req)
+            self.tree_cache.cache_unfinished_req(self.chunked_req)
+            # chunked request keeps its rid but will get a new req_pool_idx
+            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
-            if self.chunked_req:
-                # Move the chunked request out of the batch so that we can merge
-                # only finished requests to running_batch.
-                self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
-                self.tree_cache.cache_unfinished_req(self.chunked_req)
-                # chunked request keeps its rid but will get a new req_pool_idx
-                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
-                self.running_batch.batch_is_full = False
+            if self.last_batch.chunked_req is not None:
+                # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
+                # We need to discard it.
+                chunked_req_to_exclude.add(self.last_batch.chunked_req)

             # Filter batch
             last_bs = self.last_batch.batch_size()
-            self.last_batch.filter_batch()
+            self.last_batch.filter_batch(
+                chunked_req_to_exclude=list(chunked_req_to_exclude)
+            )
             if self.last_batch.batch_size() < last_bs:
                 self.running_batch.batch_is_full = False

@@ -1171,6 +1324,12 @@ class Scheduler(

         return ret

+    def get_num_allocatable_reqs(self, running_bs):
+        res = global_server_args_dict["max_micro_batch_size"] - running_bs
+        if self.pp_size > 1:
+            res = min(res, self.req_to_token_pool.available_size())
+        return res
+
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue
         if self.grammar_queue:
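
get_num_allocatable_reqs bounds admission by the per-microbatch budget and, when pp_size > 1, also by the free slots in req_to_token_pool. A worked example with illustrative numbers:

# Worked example (numbers are illustrative).
max_micro_batch_size = 64   # global_server_args_dict["max_micro_batch_size"]
running_bs = 60             # requests already in this microbatch
free_req_slots = 3          # req_to_token_pool.available_size()
pp_size = 4

res = max_micro_batch_size - running_bs   # 4 seats left in the microbatch
if pp_size > 1:
    res = min(res, free_req_slots)        # further capped by free pool slots
assert res == 3
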
@@ -1183,7 +1342,12 @@ class Scheduler(
             return None

         running_bs = len(self.running_batch.reqs)
-        if running_bs >= self.max_running_requests:
+        # Igore the check if self.chunked_req is not None.
+        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
+        # as the space for the chunked request has just been released.
+        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
+        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
+        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
             self.running_batch.batch_is_full = True
             return None

@@ -1227,7 +1391,7 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break

-            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
+            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                 self.running_batch.batch_is_full = True
                 break

@@ -1239,16 +1403,14 @@ class Scheduler(
             res = adder.add_one_req(
                 req, self.chunked_req, self.enable_hierarchical_cache
             )
+
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
                     if self.enable_hierarchical_cache:
                         # Set batch_is_full after making sure there are requests that can be served
                         self.running_batch.batch_is_full = len(
                             adder.can_run_list
-                        ) > 0 or (
-                            self.running_batch is not None
-                            and not self.running_batch.is_empty()
-                        )
+                        ) > 0 or (not self.running_batch.is_empty())
                     else:
                         self.running_batch.batch_is_full = True
                 break
@@ -1291,6 +1453,7 @@ class Scheduler(
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
+            chunked_req=self.chunked_req,
         )
         new_batch.prepare_for_extend()

@@ -1368,9 +1531,14 @@ class Scheduler(
         if self.is_generation:
             if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
-                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    model_worker_batch
-                )
+                if self.pp_group.is_last_rank:
+                    logits_output, next_token_ids = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
+                else:
+                    pp_hidden_states_proxy_tensors, _ = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
                 bid = model_worker_batch.bid
             else:
                 (
@@ -1384,7 +1552,9 @@ class Scheduler(
                 )
                 self.spec_num_total_forward_ct += batch.batch_size()
                 self.num_generated_tokens += num_accepted_tokens
-            batch.output_ids = next_token_ids
+
+            if self.pp_group.is_last_rank:
+                batch.output_ids = next_token_ids

             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
@@ -1399,8 +1569,13 @@ class Scheduler(
                 extend_logprob_start_len_per_req = None

             ret = GenerationBatchResult(
-                logits_output=logits_output,
-                next_token_ids=next_token_ids,
+                logits_output=logits_output if self.pp_group.is_last_rank else None,
+                pp_hidden_states_proxy_tensors=(
+                    pp_hidden_states_proxy_tensors
+                    if not self.pp_group.is_last_rank
+                    else None
+                ),
+                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                 bid=bid,
@@ -1417,14 +1592,15 @@ class Scheduler(
         self,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+        launch_done: Optional[threading.Event] = None,
     ):
         if batch.forward_mode.is_decode():
-            self.process_batch_result_decode(batch, result)
+            self.process_batch_result_decode(batch, result, launch_done)
         elif batch.forward_mode.is_extend():
-            self.process_batch_result_prefill(batch, result)
+            self.process_batch_result_prefill(batch, result, launch_done)
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-                self.tp_worker.resolve_batch_result(result.bid)
+                self.tp_worker.resolve_last_batch_result(launch_done)
                 if batch.next_batch_sampling_info:
                     batch.next_batch_sampling_info.update_regex_vocab_mask()
                     self.current_stream.synchronize()
@@ -1550,6 +1726,7 @@ class Scheduler(

     def move_ready_grammar_requests(self):
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
+
         num_ready_reqs = 0
         for req in self.grammar_queue:
             try:
@@ -1616,7 +1793,11 @@ class Scheduler(

     def flush_cache(self):
         """Flush the memory pool and cache."""
-        if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
+        if (
+            len(self.waiting_queue) == 0
+            and self.running_batch.is_empty()
+            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
+        ):
             self.cur_batch = None
             self.last_batch = None
             self.tree_cache.reset()
@@ -1654,7 +1835,6 @@ class Scheduler(
             ret["avg_spec_accept_length"] = (
                 self.cum_spec_accept_length / self.cum_spec_accept_count
             )
-
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict
         return GetInternalStateReqOutput(
@@ -1665,6 +1845,7 @@ class Scheduler(
         server_args_dict = recv_req.server_args
         args_allow_update = set(
             [
+                "max_micro_batch_size",
                 "speculative_accept_threshold_single",
                 "speculative_accept_threshold_acc",
             ]
@@ -1675,6 +1856,14 @@ class Scheduler(
                 logging.warning(f"Updating {k} is not supported.")
                 if_success = False
                 break
+            elif k == "max_micro_batch_size" and (
+                v > self.max_running_requests // self.pp_size or v < 1
+            ):
+                logging.warning(
+                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
+                )
+                if_success = False
+                break
         if if_success:
             if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                 avg_spec_accept_length = (
@@ -1956,6 +2145,16 @@ class Scheduler(
         else:
             del self.sessions[session_id]

+    def get_print_prefix(self):
+        prefix = ""
+        if self.dp_rank is not None:
+            prefix += f" DP{self.dp_rank}"
+        if self.server_args.tp_size > 1:
+            prefix += f" TP{self.tp_rank}"
+        if self.pp_size > 1:
+            prefix += f" PP{self.pp_rank}"
+        return prefix
+

 def is_health_check_generate_req(recv_req):
     return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
@@ -1980,14 +2179,18 @@ def run_scheduler_process(
     port_args: PortArgs,
     gpu_id: int,
     tp_rank: int,
+    pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
 ):
     # Generate the prefix
-    if dp_rank is None:
-        prefix = f" TP{tp_rank}"
-    else:
-        prefix = f" DP{dp_rank} TP{tp_rank}"
+    prefix = ""
+    if dp_rank is not None:
+        prefix += f" DP{dp_rank}"
+    if server_args.tp_size > 1:
+        prefix += f" TP{tp_rank}"
+    if server_args.pp_size > 1:
+        prefix += f" PP{pp_rank}"

     # Config the process
     kill_itself_when_parent_died()
@@ -2009,7 +2212,7 @@ def run_scheduler_process(

     # Create a scheduler and run the event loop
     try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
+        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
         pipe_writer.send(
             {
                 "status": "ready",
@@ -2020,7 +2223,9 @@ def run_scheduler_process(
         disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode

         if disaggregation_mode == DisaggregationMode.NULL:
-            if scheduler.enable_overlap:
+            if server_args.pp_size > 1:
+                scheduler.event_loop_pp()
+            elif scheduler.enable_overlap:
                 scheduler.event_loop_overlap()
             else:
                 scheduler.event_loop_normal()
@@ -2029,6 +2234,7 @@ def run_scheduler_process(
                 scheduler.event_loop_overlap_disagg_prefill()
             else:
                 scheduler.event_loop_normal_disagg_prefill()
+
         elif disaggregation_mode == DisaggregationMode.DECODE:
             if scheduler.enable_overlap:
                 scheduler.event_loop_overlap_disagg_decode()
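
run_scheduler_process now selects event_loop_pp whenever server_args.pp_size > 1. A minimal sketch of exercising that path from the offline Engine API, assuming Engine keyword arguments still map onto ServerArgs fields (including the pp_size field used throughout this diff) and that the placeholder model path is replaced with a locally available model:

# Sketch: reach the pp_size > 1 branch shown above via the offline engine.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
        tp_size=1,
        pp_size=2,  # two pipeline stages -> scheduler runs event_loop_pp
    )
    print(llm.generate("Hello, my name is", {"max_new_tokens": 8}))
    llm.shutdown()
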