sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/check_env.py +3 -3
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/kimi_vl.py +38 -0
  5. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  6. sglang/srt/configs/model_config.py +15 -0
  7. sglang/srt/conversation.py +122 -1
  8. sglang/srt/entrypoints/engine.py +44 -22
  9. sglang/srt/function_call_parser.py +97 -0
  10. sglang/srt/hf_transformers_utils.py +2 -0
  11. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  12. sglang/srt/layers/attention/flashinfer_backend.py +107 -82
  13. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
  14. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  15. sglang/srt/layers/dp_attention.py +5 -2
  16. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -6
  22. sglang/srt/layers/quantization/__init__.py +2 -2
  23. sglang/srt/layers/quantization/deep_gemm.py +1 -1
  24. sglang/srt/layers/utils.py +35 -0
  25. sglang/srt/lora/layers.py +35 -9
  26. sglang/srt/lora/lora_manager.py +84 -35
  27. sglang/srt/managers/data_parallel_controller.py +52 -34
  28. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  29. sglang/srt/managers/schedule_batch.py +25 -15
  30. sglang/srt/managers/scheduler.py +263 -59
  31. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  32. sglang/srt/managers/tp_worker.py +51 -16
  33. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  34. sglang/srt/mem_cache/memory_pool.py +70 -36
  35. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  36. sglang/srt/model_executor/forward_batch_info.py +31 -1
  37. sglang/srt/model_executor/model_runner.py +115 -57
  38. sglang/srt/models/deepseek_nextn.py +1 -257
  39. sglang/srt/models/deepseek_v2.py +78 -18
  40. sglang/srt/models/kimi_vl.py +308 -0
  41. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  42. sglang/srt/models/llama.py +92 -30
  43. sglang/srt/models/llama4.py +2 -1
  44. sglang/srt/models/llama_eagle.py +4 -1
  45. sglang/srt/models/llama_eagle3.py +4 -1
  46. sglang/srt/models/qwen2_moe.py +8 -3
  47. sglang/srt/models/qwen2_vl.py +0 -12
  48. sglang/srt/models/qwen3_moe.py +8 -3
  49. sglang/srt/openai_api/adapter.py +34 -22
  50. sglang/srt/openai_api/protocol.py +11 -1
  51. sglang/srt/server_args.py +67 -22
  52. sglang/srt/speculative/eagle_worker.py +3 -2
  53. sglang/srt/utils.py +88 -9
  54. sglang/test/runners.py +4 -0
  55. sglang/test/test_utils.py +29 -0
  56. sglang/version.py +1 -1
  57. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
  58. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +61 -51
  59. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
  60. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
  61. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py

@@ -51,6 +51,7 @@ from sglang.srt.disaggregation.utils import (
     ReqToMetadataIdxAllocator,
     TransferBackend,
 )
+from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -114,7 +115,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -127,6 +132,7 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_zmq_socket,
     kill_itself_when_parent_died,
+    point_to_point_pyobj,
     pyspy_dump_schedulers,
     set_gpu_proc_affinity,
     set_random_seed,
@@ -145,8 +151,9 @@ RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
 
 @dataclass
 class GenerationBatchResult:
-    logits_output: LogitsProcessorOutput
-    next_token_ids: List[int]
+    logits_output: Optional[LogitsProcessorOutput]
+    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
+    next_token_ids: Optional[List[int]]
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
     bid: int
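
With pipeline parallelism, only the last stage samples tokens; earlier stages hand hidden-state proxy tensors to the next stage instead, which is why these fields become Optional. A hedged sketch of how a consumer might branch on the two shapes (`result` and `is_last_rank` are assumed inputs, not sglang APIs):

```python
# Sketch only: distinguishing the two shapes of GenerationBatchResult above.
def handle_result(result, is_last_rank: bool):
    if is_last_rank:
        # the final pipeline stage has logits and sampled token ids
        assert result.next_token_ids is not None
        return result.next_token_ids
    # earlier stages only carry hidden states for the next stage
    assert result.pp_hidden_states_proxy_tensors is not None
    return result.pp_hidden_states_proxy_tensors
```
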
@@ -171,12 +178,16 @@ class Scheduler(
         port_args: PortArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
     ):
         # Parse args
         self.server_args = server_args
         self.tp_rank = tp_rank
+        self.pp_rank = pp_rank
         self.tp_size = server_args.tp_size
+        self.pp_size = server_args.pp_size
+        self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
@@ -192,7 +203,6 @@ class Scheduler(
         self.page_size = server_args.page_size
 
         # Distributed rank info
-        self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
@@ -204,7 +214,7 @@ class Scheduler(
 
         # Init inter-process communication
         context = zmq.Context(2)
-        if self.attn_tp_rank == 0:
+        if self.pp_rank == 0 and self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
             )
@@ -259,6 +269,7 @@ class Scheduler(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
+            pp_rank=pp_rank,
             dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )
@@ -292,8 +303,18 @@ class Scheduler(
             _,
             _,
         ) = self.tp_worker.get_worker_info()
-        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+        if global_server_args_dict["max_micro_batch_size"] is None:
+            global_server_args_dict["max_micro_batch_size"] = max(
+                self.max_running_requests // server_args.pp_size, 1
+            )
+
+        self.tp_group = self.tp_worker.get_tp_group()
+        self.tp_cpu_group = self.tp_group.cpu_group
+        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
         self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
+        self.pp_group = get_pp_group()
+        self.world_group = get_world_group()
+
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)
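
When no explicit value is configured, the micro-batch cap defaults to the running-request budget split evenly across pipeline stages, clamped to at least 1. A quick illustration of that formula with made-up numbers:

```python
# Made-up numbers, just to illustrate the default computed above.
max_running_requests = 128
pp_size = 4

max_micro_batch_size = max(max_running_requests // pp_size, 1)
print(max_micro_batch_size)  # 32
```
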
@@ -673,26 +694,141 @@ class Scheduler(
 
             self.last_batch = batch
 
+    @DynamicGradMode()
+    def event_loop_pp(self):
+        """A non-overlap scheduler loop for pipeline parallelism."""
+        mbs = [None] * self.pp_size
+        last_mbs = [None] * self.pp_size
+        self.running_mbs = [
+            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
+        ]
+        bids = [None] * self.pp_size
+        pp_outputs: Optional[PPProxyTensors] = None
+        while True:
+            server_is_idle = True
+            for mb_id in range(self.pp_size):
+                self.running_batch = self.running_mbs[mb_id]
+                self.last_batch = last_mbs[mb_id]
+
+                recv_reqs = self.recv_requests()
+                self.process_input_requests(recv_reqs)
+                mbs[mb_id] = self.get_next_batch_to_run()
+                self.running_mbs[mb_id] = self.running_batch
+
+                self.cur_batch = mbs[mb_id]
+                if self.cur_batch:
+                    server_is_idle = False
+                    result = self.run_batch(self.cur_batch)
+
+                # send the outputs to the next step
+                if self.pp_group.is_last_rank:
+                    if self.cur_batch:
+                        next_token_ids, bids[mb_id] = (
+                            result.next_token_ids,
+                            result.bid,
+                        )
+                        pp_outputs = PPProxyTensors(
+                            {
+                                "next_token_ids": next_token_ids,
+                            }
+                        )
+                        # send the output from the last round to let the next stage worker run post processing
+                        self.pp_group.send_tensor_dict(
+                            pp_outputs.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                # receive outputs and post-process (filter finished reqs) the coming microbatch
+                next_mb_id = (mb_id + 1) % self.pp_size
+                next_pp_outputs = None
+                if mbs[next_mb_id] is not None:
+                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
+                        self.pp_group.recv_tensor_dict(
+                            all_gather_group=self.attn_tp_group
+                        )
+                    )
+                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
+                    output_result = GenerationBatchResult(
+                        logits_output=None,
+                        pp_hidden_states_proxy_tensors=None,
+                        next_token_ids=next_pp_outputs["next_token_ids"],
+                        extend_input_len_per_req=None,
+                        extend_logprob_start_len_per_req=None,
+                        bid=bids[next_mb_id],
+                    )
+                    self.process_batch_result(mbs[next_mb_id], output_result)
+                    last_mbs[next_mb_id] = mbs[next_mb_id]
+
+                # carry the outputs to the next stage
+                if not self.pp_group.is_last_rank:
+                    if self.cur_batch:
+                        bids[mb_id] = result.bid
+                    if pp_outputs:
+                        # send the outputs from the last round to let the next stage worker run post processing
+                        self.pp_group.send_tensor_dict(
+                            pp_outputs.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                if not self.pp_group.is_last_rank:
+                    # send out reqs to the next stage
+                    dp_offset = self.dp_rank * self.attn_tp_size
+                    if self.attn_tp_rank == 0:
+                        point_to_point_pyobj(
+                            recv_reqs,
+                            self.pp_rank * self.tp_size + dp_offset,
+                            self.world_group.cpu_group,
+                            self.pp_rank * self.tp_size + dp_offset,
+                            (self.pp_rank + 1) * self.tp_size + dp_offset,
+                        )
+
+                    # send out proxy tensors to the next stage
+                    if self.cur_batch:
+                        self.pp_group.send_tensor_dict(
+                            result.pp_hidden_states_proxy_tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                pp_outputs = next_pp_outputs
+
+            # When the server is idle, self-check and re-init some states
+            if server_is_idle:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
-        if self.attn_tp_rank == 0:
-            recv_reqs = []
-
-            while True:
-                try:
-                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-                except zmq.ZMQError:
-                    break
-                recv_reqs.append(recv_req)
-
-            while True:
-                try:
-                    recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
-                except zmq.ZMQError:
-                    break
-                recv_reqs.append(recv_rpc)
+        if self.pp_rank == 0:
+            if self.attn_tp_rank == 0:
+                recv_reqs = []
+
+                while True:
+                    try:
+                        recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                    except zmq.ZMQError:
+                        break
+                    recv_reqs.append(recv_req)
+
+                while True:
+                    try:
+                        recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
+                    except zmq.ZMQError:
+                        break
+                    recv_reqs.append(recv_rpc)
+            else:
+                recv_reqs = None
         else:
-            recv_reqs = None
+            if self.attn_tp_rank == 0:
+                dp_offset = self.dp_rank * self.attn_tp_size
+                recv_reqs = point_to_point_pyobj(
+                    [],
+                    self.pp_rank * self.tp_size + dp_offset,
+                    self.world_group.cpu_group,
+                    (self.pp_rank - 1) * self.tp_size + dp_offset,
+                    self.pp_rank * self.tp_size + dp_offset,
+                )
+            else:
+                recv_reqs = None
 
         if self.server_args.enable_dp_attention:
             if self.attn_tp_rank == 0:
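
The core idea of event_loop_pp is a fixed rotation over pp_size microbatch slots: each iteration launches the slot whose turn it is and, in the same pass, finalizes the neighbouring slot, whose batch was launched on the previous pass over that slot and has since flowed through the remaining pipeline stages. A toy trace of that ordering (illustration only, not sglang code):

```python
# Toy trace of the microbatch rotation used by event_loop_pp (not sglang code).
pp_size = 4
slots = [None] * pp_size  # which batch id currently occupies each slot

for step in range(8):
    mb_id = step % pp_size
    nxt = (mb_id + 1) % pp_size
    msg = f"step {step}: launch batch {step} in slot {mb_id}"
    if slots[nxt] is not None:
        # this batch has now passed through every pipeline stage
        msg += f", finalize batch {slots[nxt]} from slot {nxt}"
    print(msg)
    slots[mb_id] = step
```
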
@@ -715,20 +851,27 @@ class Scheduler(
                 control_reqs = None
 
             if self.attn_tp_size != 1:
-                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
                 work_reqs = broadcast_pyobj(
                     work_reqs,
-                    self.attn_tp_rank,
+                    self.attn_tp_group.rank,
                     self.attn_tp_cpu_group,
-                    src=attn_tp_rank_0,
+                    src=self.attn_tp_group.ranks[0],
                 )
             if self.tp_size != 1:
                 control_reqs = broadcast_pyobj(
-                    control_reqs, self.tp_rank, self.tp_cpu_group
+                    control_reqs,
+                    self.tp_group.rank,
+                    self.tp_cpu_group,
+                    src=self.tp_group.ranks[0],
                 )
             recv_reqs = work_reqs + control_reqs
         elif self.tp_size != 1:
-            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
+            recv_reqs = broadcast_pyobj(
+                recv_reqs,
+                self.tp_group.rank,
+                self.tp_cpu_group,
+                src=self.tp_group.ranks[0],
+            )
         return recv_reqs
 
     def process_input_requests(self, recv_reqs: List):
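
For the stage-to-stage request hand-off, the global source and destination ranks are derived from pp_rank, tp_size, and the data-parallel offset, so a request appears to land on the same (dp_rank, attn_tp_rank = 0) slot of the next pipeline stage. A worked example of that arithmetic with made-up sizes (not an sglang API call):

```python
# Made-up sizes; this only reproduces the rank arithmetic used around
# point_to_point_pyobj in the hunks above.
tp_size = 8        # ranks per pipeline stage
attn_tp_size = 4   # TP ranks per attention-DP group
dp_rank = 1
pp_rank = 0        # the sending stage

dp_offset = dp_rank * attn_tp_size          # 4
src = pp_rank * tp_size + dp_offset         # 4: this stage's slot
dst = (pp_rank + 1) * tp_size + dp_offset   # 12: same slot on the next stage
print(src, dst)  # 4 12
```
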
@@ -900,6 +1043,7 @@ class Scheduler(
                 add_to_grammar_queue = True
 
         if add_to_grammar_queue:
+            req.queue_time_start = time.time()
            self.grammar_queue.append(req)
         else:
             self._add_request_to_queue(req)
@@ -1025,12 +1169,14 @@ class Scheduler(
 
             self.metrics_collector.log_stats(self.stats)
 
-    def log_decode_stats(self):
+    def log_decode_stats(self, running_batch=None):
+        batch = running_batch or self.running_batch
+
         gap_latency = time.time() - self.last_decode_stats_tic
         self.last_decode_stats_tic = time.time()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
-        num_running_reqs = len(self.running_batch.reqs)
+        num_running_reqs = len(batch.reqs)
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -1130,19 +1276,25 @@ class Scheduler(
 
     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
+        chunked_req_to_exclude = set()
+        if self.chunked_req:
+            # Move the chunked request out of the batch so that we can merge
+            # only finished requests to running_batch.
+            chunked_req_to_exclude.add(self.chunked_req)
+            self.tree_cache.cache_unfinished_req(self.chunked_req)
+            # chunked request keeps its rid but will get a new req_pool_idx
+            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
-            if self.chunked_req:
-                # Move the chunked request out of the batch so that we can merge
-                # only finished requests to running_batch.
-                self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
-                self.tree_cache.cache_unfinished_req(self.chunked_req)
-                # chunked request keeps its rid but will get a new req_pool_idx
-                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
-                self.running_batch.batch_is_full = False
+            if self.last_batch.chunked_req is not None:
+                # In the context of pipeline parallelism, after the last chunk, the current microbatch still tracks an outdated chunked_req.
+                # We need to discard it.
+                chunked_req_to_exclude.add(self.last_batch.chunked_req)
 
             # Filter batch
             last_bs = self.last_batch.batch_size()
-            self.last_batch.filter_batch()
+            self.last_batch.filter_batch(
+                chunked_req_to_exclude=list(chunked_req_to_exclude)
+            )
             if self.last_batch.batch_size() < last_bs:
                 self.running_batch.batch_is_full = False
 
@@ -1172,6 +1324,12 @@ class Scheduler(
 
         return ret
 
+    def get_num_allocatable_reqs(self, running_bs):
+        res = global_server_args_dict["max_micro_batch_size"] - running_bs
+        if self.pp_size > 1:
+            res = min(res, self.req_to_token_pool.available_size())
+        return res
+
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue
         if self.grammar_queue:
@@ -1184,7 +1342,12 @@ class Scheduler(
             return None
 
         running_bs = len(self.running_batch.reqs)
-        if running_bs >= self.max_running_requests:
+        # Ignore the check if self.chunked_req is not None.
+        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
+        # as the space for the chunked request has just been released.
+        # In the PP case, a chunked req can start in one microbatch and end in another, so the max_running_requests limit per microbatch should not be strict.
+        # Instead, we should always allow the chunked request to be added; otherwise, there will be a memory leak.
+        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
             self.running_batch.batch_is_full = True
             return None
 
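Putting the two pieces together: prefill admission is now bounded by the per-microbatch budget (and, under PP, by free req_to_token_pool slots), while a pending chunked request always gets through so its already-reserved memory can be released. A standalone sketch with made-up numbers (not sglang APIs):

```python
# Standalone sketch of the new admission check (made-up numbers, not sglang APIs).
max_micro_batch_size = 16
pp_size = 2

def num_allocatable(running_bs: int, free_req_slots: int) -> int:
    res = max_micro_batch_size - running_bs
    if pp_size > 1:
        # under PP, admission is also capped by free req_to_token_pool slots
        res = min(res, free_req_slots)
    return res

running_bs, free_req_slots, have_chunked_req = 16, 3, True
if num_allocatable(running_bs, free_req_slots) <= 0 and not have_chunked_req:
    print("microbatch is full; skip prefill")
else:
    print("admit more requests (a pending chunked request is always allowed)")
```
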
@@ -1228,7 +1391,7 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break
 
-            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
+            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                 self.running_batch.batch_is_full = True
                 break
 
@@ -1240,16 +1403,14 @@ class Scheduler(
             res = adder.add_one_req(
                 req, self.chunked_req, self.enable_hierarchical_cache
             )
+
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
                     if self.enable_hierarchical_cache:
                         # Set batch_is_full after making sure there are requests that can be served
                         self.running_batch.batch_is_full = len(
                             adder.can_run_list
-                        ) > 0 or (
-                            self.running_batch is not None
-                            and not self.running_batch.is_empty()
-                        )
+                        ) > 0 or (not self.running_batch.is_empty())
                     else:
                         self.running_batch.batch_is_full = True
                 break
@@ -1292,6 +1453,7 @@ class Scheduler(
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
+            chunked_req=self.chunked_req,
         )
         new_batch.prepare_for_extend()
 
@@ -1369,9 +1531,14 @@ class Scheduler(
         if self.is_generation:
             if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
-                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    model_worker_batch
-                )
+                if self.pp_group.is_last_rank:
+                    logits_output, next_token_ids = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
+                else:
+                    pp_hidden_states_proxy_tensors, _ = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
                 bid = model_worker_batch.bid
             else:
                 (
@@ -1385,7 +1552,9 @@ class Scheduler(
                 )
                 self.spec_num_total_forward_ct += batch.batch_size()
                 self.num_generated_tokens += num_accepted_tokens
-            batch.output_ids = next_token_ids
+
+            if self.pp_group.is_last_rank:
+                batch.output_ids = next_token_ids
 
             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
@@ -1400,8 +1569,13 @@ class Scheduler(
                 extend_logprob_start_len_per_req = None
 
             ret = GenerationBatchResult(
-                logits_output=logits_output,
-                next_token_ids=next_token_ids,
+                logits_output=logits_output if self.pp_group.is_last_rank else None,
+                pp_hidden_states_proxy_tensors=(
+                    pp_hidden_states_proxy_tensors
+                    if not self.pp_group.is_last_rank
+                    else None
+                ),
+                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                 bid=bid,
@@ -1552,6 +1726,7 @@ class Scheduler(
 
     def move_ready_grammar_requests(self):
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
+
         num_ready_reqs = 0
         for req in self.grammar_queue:
             try:
@@ -1618,7 +1793,11 @@ class Scheduler(
 
     def flush_cache(self):
         """Flush the memory pool and cache."""
-        if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
+        if (
+            len(self.waiting_queue) == 0
+            and self.running_batch.is_empty()
+            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
+        ):
             self.cur_batch = None
             self.last_batch = None
             self.tree_cache.reset()
@@ -1656,7 +1835,6 @@ class Scheduler(
             ret["avg_spec_accept_length"] = (
                 self.cum_spec_accept_length / self.cum_spec_accept_count
             )
-
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict
         return GetInternalStateReqOutput(
@@ -1667,6 +1845,7 @@ class Scheduler(
         server_args_dict = recv_req.server_args
         args_allow_update = set(
             [
+                "max_micro_batch_size",
                 "speculative_accept_threshold_single",
                 "speculative_accept_threshold_acc",
             ]
@@ -1677,6 +1856,14 @@ class Scheduler(
                 logging.warning(f"Updating {k} is not supported.")
                 if_success = False
                 break
+            elif k == "max_micro_batch_size" and (
+                v > self.max_running_requests // self.pp_size or v < 1
+            ):
+                logging.warning(
+                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
+                )
+                if_success = False
+                break
         if if_success:
             if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                 avg_spec_accept_length = (
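
The update path accepts a new max_micro_batch_size only inside [1, max_running_requests // pp_size]. A quick check with made-up numbers:

```python
# Made-up numbers illustrating the accepted range for max_micro_batch_size.
max_running_requests, pp_size = 128, 4
upper = max_running_requests // pp_size   # 32

for v in (0, 16, 32, 33):
    ok = 1 <= v <= upper
    print(v, "accepted" if ok else "rejected")
```
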
@@ -1958,6 +2145,16 @@ class Scheduler(
         else:
             del self.sessions[session_id]
 
+    def get_print_prefix(self):
+        prefix = ""
+        if self.dp_rank is not None:
+            prefix += f" DP{self.dp_rank}"
+        if self.server_args.tp_size > 1:
+            prefix += f" TP{self.tp_rank}"
+        if self.pp_size > 1:
+            prefix += f" PP{self.pp_rank}"
+        return prefix
+
 
 def is_health_check_generate_req(recv_req):
     return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
@@ -1982,14 +2179,18 @@ def run_scheduler_process(
     port_args: PortArgs,
     gpu_id: int,
     tp_rank: int,
+    pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
 ):
     # Generate the prefix
-    if dp_rank is None:
-        prefix = f" TP{tp_rank}"
-    else:
-        prefix = f" DP{dp_rank} TP{tp_rank}"
+    prefix = ""
+    if dp_rank is not None:
+        prefix += f" DP{dp_rank}"
+    if server_args.tp_size > 1:
+        prefix += f" TP{tp_rank}"
+    if server_args.pp_size > 1:
+        prefix += f" PP{pp_rank}"
 
     # Config the process
     kill_itself_when_parent_died()
@@ -2011,7 +2212,7 @@ def run_scheduler_process(
 
     # Create a scheduler and run the event loop
     try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
+        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
         pipe_writer.send(
             {
                 "status": "ready",
@@ -2022,7 +2223,9 @@ def run_scheduler_process(
         disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
 
         if disaggregation_mode == DisaggregationMode.NULL:
-            if scheduler.enable_overlap:
+            if server_args.pp_size > 1:
+                scheduler.event_loop_pp()
+            elif scheduler.enable_overlap:
                 scheduler.event_loop_overlap()
             else:
                 scheduler.event_loop_normal()
@@ -2031,6 +2234,7 @@ def run_scheduler_process(
                 scheduler.event_loop_overlap_disagg_prefill()
             else:
                 scheduler.event_loop_normal_disagg_prefill()
+
         elif disaggregation_mode == DisaggregationMode.DECODE:
             if scheduler.enable_overlap:
                 scheduler.event_loop_overlap_disagg_decode()

sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -278,7 +278,7 @@ class SchedulerOutputProcessorMixin:
             self.attn_tp_rank == 0
             and self.forward_ct_decode % self.server_args.decode_log_interval == 0
         ):
-            self.log_decode_stats()
+            self.log_decode_stats(running_batch=batch)
 
     def add_input_logprob_return_values(
         self: Scheduler,