sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (119)
  1. sglang/bench_one_batch.py +3 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/check_env.py +3 -3
  4. sglang/lang/chat_template.py +44 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/deepseekvl2.py +3 -0
  7. sglang/srt/configs/device_config.py +1 -1
  8. sglang/srt/configs/internvl.py +696 -0
  9. sglang/srt/configs/janus_pro.py +3 -0
  10. sglang/srt/configs/kimi_vl.py +38 -0
  11. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  12. sglang/srt/configs/model_config.py +32 -0
  13. sglang/srt/constrained/xgrammar_backend.py +11 -19
  14. sglang/srt/conversation.py +151 -3
  15. sglang/srt/disaggregation/decode.py +4 -1
  16. sglang/srt/disaggregation/mini_lb.py +74 -23
  17. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  18. sglang/srt/disaggregation/nixl/conn.py +241 -71
  19. sglang/srt/disaggregation/utils.py +44 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  21. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  22. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  24. sglang/srt/distributed/parallel_state.py +22 -1
  25. sglang/srt/entrypoints/engine.py +58 -24
  26. sglang/srt/entrypoints/http_server.py +28 -1
  27. sglang/srt/entrypoints/verl_engine.py +3 -2
  28. sglang/srt/function_call_parser.py +97 -0
  29. sglang/srt/hf_transformers_utils.py +22 -1
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  31. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  32. sglang/srt/layers/attention/flashinfer_backend.py +129 -94
  33. sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
  34. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  35. sglang/srt/layers/attention/merge_state.py +46 -0
  36. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  37. sglang/srt/layers/attention/vision.py +290 -163
  38. sglang/srt/layers/dp_attention.py +5 -2
  39. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  40. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
  49. sglang/srt/layers/quantization/__init__.py +2 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  52. sglang/srt/layers/quantization/deep_gemm.py +6 -1
  53. sglang/srt/layers/quantization/fp8.py +108 -95
  54. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  55. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  56. sglang/srt/layers/quantization/kv_cache.py +3 -10
  57. sglang/srt/layers/quantization/utils.py +0 -5
  58. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  59. sglang/srt/layers/utils.py +35 -0
  60. sglang/srt/lora/layers.py +35 -9
  61. sglang/srt/lora/lora_manager.py +81 -35
  62. sglang/srt/managers/cache_controller.py +115 -119
  63. sglang/srt/managers/data_parallel_controller.py +52 -34
  64. sglang/srt/managers/io_struct.py +10 -0
  65. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  66. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  67. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  68. sglang/srt/managers/schedule_batch.py +44 -16
  69. sglang/srt/managers/schedule_policy.py +11 -5
  70. sglang/srt/managers/scheduler.py +291 -72
  71. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  72. sglang/srt/managers/tokenizer_manager.py +24 -13
  73. sglang/srt/managers/tp_worker.py +60 -28
  74. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  75. sglang/srt/mem_cache/chunk_cache.py +2 -0
  76. sglang/srt/mem_cache/memory_pool.py +70 -36
  77. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  78. sglang/srt/model_executor/forward_batch_info.py +31 -1
  79. sglang/srt/model_executor/model_runner.py +159 -90
  80. sglang/srt/model_loader/loader.py +18 -11
  81. sglang/srt/models/clip.py +4 -4
  82. sglang/srt/models/deepseek_janus_pro.py +1 -1
  83. sglang/srt/models/deepseek_nextn.py +2 -277
  84. sglang/srt/models/deepseek_v2.py +132 -37
  85. sglang/srt/models/gemma3_mm.py +1 -1
  86. sglang/srt/models/internlm2.py +3 -0
  87. sglang/srt/models/internvl.py +670 -0
  88. sglang/srt/models/kimi_vl.py +308 -0
  89. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  90. sglang/srt/models/llama.py +93 -31
  91. sglang/srt/models/llama4.py +54 -7
  92. sglang/srt/models/llama_eagle.py +4 -1
  93. sglang/srt/models/llama_eagle3.py +4 -1
  94. sglang/srt/models/minicpmv.py +1 -1
  95. sglang/srt/models/mllama.py +1 -1
  96. sglang/srt/models/phi3_small.py +16 -2
  97. sglang/srt/models/qwen2_5_vl.py +8 -4
  98. sglang/srt/models/qwen2_moe.py +8 -3
  99. sglang/srt/models/qwen2_vl.py +4 -16
  100. sglang/srt/models/qwen3_moe.py +8 -3
  101. sglang/srt/models/xiaomi_mimo.py +171 -0
  102. sglang/srt/openai_api/adapter.py +58 -62
  103. sglang/srt/openai_api/protocol.py +38 -16
  104. sglang/srt/reasoning_parser.py +2 -2
  105. sglang/srt/sampling/sampling_batch_info.py +54 -2
  106. sglang/srt/sampling/sampling_params.py +2 -0
  107. sglang/srt/server_args.py +93 -24
  108. sglang/srt/speculative/eagle_worker.py +3 -2
  109. sglang/srt/utils.py +123 -10
  110. sglang/test/runners.py +4 -0
  111. sglang/test/test_block_fp8.py +2 -2
  112. sglang/test/test_deepep_utils.py +219 -0
  113. sglang/test/test_utils.py +32 -1
  114. sglang/version.py +1 -1
  115. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
  116. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
  117. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  118. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py

@@ -51,7 +51,12 @@ from sglang.srt.disaggregation.utils import (
     ReqToMetadataIdxAllocator,
     TransferBackend,
 )
-from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.distributed import get_pp_group, get_world_group
+from sglang.srt.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
@@ -82,6 +87,8 @@ from sglang.srt.managers.io_struct import (
     RpcReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
+    SlowDownReqInput,
+    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -114,7 +121,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -127,6 +138,7 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_zmq_socket,
     kill_itself_when_parent_died,
+    point_to_point_pyobj,
     pyspy_dump_schedulers,
     set_gpu_proc_affinity,
     set_random_seed,
@@ -145,8 +157,9 @@ RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
 
 @dataclass
 class GenerationBatchResult:
-    logits_output: LogitsProcessorOutput
-    next_token_ids: List[int]
+    logits_output: Optional[LogitsProcessorOutput]
+    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
+    next_token_ids: Optional[List[int]]
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
     bid: int
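Note: after this change, GenerationBatchResult carries either sampled tokens (populated on the last pipeline-parallel rank) or hidden-state proxy tensors (populated on earlier ranks), never both. A minimal sketch of the two cases, with placeholder values for the per-request lengths and batch id:

    # Last PP rank: sampling happened, so logits and token ids are present.
    last_rank_result = GenerationBatchResult(
        logits_output=logits_output,
        pp_hidden_states_proxy_tensors=None,
        next_token_ids=next_token_ids,
        extend_input_len_per_req=[],      # placeholder
        extend_logprob_start_len_per_req=[],
        bid=0,                            # placeholder batch id
    )
    # Earlier PP ranks: only the hidden states to forward downstream.
    mid_rank_result = GenerationBatchResult(
        logits_output=None,
        pp_hidden_states_proxy_tensors=hidden_states,
        next_token_ids=None,
        extend_input_len_per_req=[],
        extend_logprob_start_len_per_req=[],
        bid=0,
    )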
@@ -171,12 +184,16 @@ class Scheduler(
         port_args: PortArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
     ):
         # Parse args
         self.server_args = server_args
         self.tp_rank = tp_rank
+        self.pp_rank = pp_rank
         self.tp_size = server_args.tp_size
+        self.pp_size = server_args.pp_size
+        self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
@@ -192,7 +209,6 @@ class Scheduler(
         self.page_size = server_args.page_size
 
         # Distributed rank info
-        self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
@@ -204,7 +220,7 @@ class Scheduler(
 
         # Init inter-process communication
         context = zmq.Context(2)
-        if self.attn_tp_rank == 0:
+        if self.pp_rank == 0 and self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
             )
@@ -259,6 +275,7 @@ class Scheduler(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
+            pp_rank=pp_rank,
             dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )
@@ -292,8 +309,18 @@ class Scheduler(
             _,
             _,
         ) = self.tp_worker.get_worker_info()
-        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+        if global_server_args_dict["max_micro_batch_size"] is None:
+            global_server_args_dict["max_micro_batch_size"] = max(
+                self.max_running_requests // server_args.pp_size, 1
+            )
+
+        self.tp_group = self.tp_worker.get_tp_group()
+        self.tp_cpu_group = self.tp_group.cpu_group
+        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
         self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
+        self.pp_group = get_pp_group()
+        self.world_group = get_world_group()
+
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)
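Note: the default computed here splits the running-request budget evenly across pipeline stages, with a floor of one. As a worked example, assuming max_running_requests = 128 and pp_size = 4:

    # 128 // 4 == 32 concurrent requests per microbatch; the max(..., 1) floor
    # keeps at least one request schedulable even if pp_size were ever larger
    # than max_running_requests.
    max_micro_batch_size = max(128 // 4, 1)  # == 32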
@@ -392,6 +419,8 @@ class Scheduler(
         self.profiler_id: Optional[str] = None
         self.profiler_target_forward_ct: Optional[int] = None
 
+        self.forward_sleep_time = None
+
         # Init metrics stats
         self.init_metrics()
 
@@ -414,6 +443,7 @@ class Scheduler(
                 (GetWeightsByNameReqInput, self.get_weights_by_name),
                 (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
+                (SlowDownReqInput, self.slow_down),
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
@@ -430,17 +460,7 @@ class Scheduler(
     def init_tokenizer(self):
         server_args = self.server_args
 
-        self.model_config = ModelConfig(
-            server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-        )
+        self.model_config = ModelConfig.from_server_args(server_args)
         self.is_generation = self.model_config.is_generation
 
         if server_args.skip_tokenizer_init:
@@ -454,7 +474,7 @@ class Scheduler(
                     revision=server_args.revision,
                     use_fast=not server_args.disable_fast_image_processor,
                 )
-                self.tokenizer = self.processor.tokenizer
+                self.tokenizer = get_tokenizer_from_processor(self.processor)
             else:
                 self.tokenizer = get_tokenizer(
                     server_args.tokenizer_path,
@@ -477,6 +497,7 @@ class Scheduler(
             self.tree_cache = ChunkCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
             )
         else:
             if self.enable_hierarchical_cache:
@@ -673,26 +694,141 @@ class Scheduler(
 
         self.last_batch = batch
 
+    @DynamicGradMode()
+    def event_loop_pp(self):
+        """A non-overlap scheduler loop for pipeline parallelism."""
+        mbs = [None] * self.pp_size
+        last_mbs = [None] * self.pp_size
+        self.running_mbs = [
+            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
+        ]
+        bids = [None] * self.pp_size
+        pp_outputs: Optional[PPProxyTensors] = None
+        while True:
+            server_is_idle = True
+            for mb_id in range(self.pp_size):
+                self.running_batch = self.running_mbs[mb_id]
+                self.last_batch = last_mbs[mb_id]
+
+                recv_reqs = self.recv_requests()
+                self.process_input_requests(recv_reqs)
+                mbs[mb_id] = self.get_next_batch_to_run()
+                self.running_mbs[mb_id] = self.running_batch
+
+                self.cur_batch = mbs[mb_id]
+                if self.cur_batch:
+                    server_is_idle = False
+                    result = self.run_batch(self.cur_batch)
+
+                # send the outputs to the next step
+                if self.pp_group.is_last_rank:
+                    if self.cur_batch:
+                        next_token_ids, bids[mb_id] = (
+                            result.next_token_ids,
+                            result.bid,
+                        )
+                        pp_outputs = PPProxyTensors(
+                            {
+                                "next_token_ids": next_token_ids,
+                            }
+                        )
+                        # send the output from the last round to let the next stage worker run post processing
+                        self.pp_group.send_tensor_dict(
+                            pp_outputs.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                # receive outputs and post-process (filter finished reqs) the coming microbatch
+                next_mb_id = (mb_id + 1) % self.pp_size
+                next_pp_outputs = None
+                if mbs[next_mb_id] is not None:
+                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
+                        self.pp_group.recv_tensor_dict(
+                            all_gather_group=self.attn_tp_group
+                        )
+                    )
+                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
+                    output_result = GenerationBatchResult(
+                        logits_output=None,
+                        pp_hidden_states_proxy_tensors=None,
+                        next_token_ids=next_pp_outputs["next_token_ids"],
+                        extend_input_len_per_req=None,
+                        extend_logprob_start_len_per_req=None,
+                        bid=bids[next_mb_id],
+                    )
+                    self.process_batch_result(mbs[next_mb_id], output_result)
+                    last_mbs[next_mb_id] = mbs[next_mb_id]
+
+                # carry the outputs to the next stage
+                if not self.pp_group.is_last_rank:
+                    if self.cur_batch:
+                        bids[mb_id] = result.bid
+                    if pp_outputs:
+                        # send the outputs from the last round to let the next stage worker run post processing
+                        self.pp_group.send_tensor_dict(
+                            pp_outputs.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                if not self.pp_group.is_last_rank:
+                    # send out reqs to the next stage
+                    dp_offset = self.dp_rank * self.attn_tp_size
+                    if self.attn_tp_rank == 0:
+                        point_to_point_pyobj(
+                            recv_reqs,
+                            self.pp_rank * self.tp_size + dp_offset,
+                            self.world_group.cpu_group,
+                            self.pp_rank * self.tp_size + dp_offset,
+                            (self.pp_rank + 1) * self.tp_size + dp_offset,
+                        )
+
+                    # send out proxy tensors to the next stage
+                    if self.cur_batch:
+                        self.pp_group.send_tensor_dict(
+                            result.pp_hidden_states_proxy_tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                pp_outputs = next_pp_outputs
+
+            # When the server is idle, self-check and re-init some states
+            if server_is_idle:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
-        if self.attn_tp_rank == 0:
-            recv_reqs = []
-
-            while True:
-                try:
-                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-                except zmq.ZMQError:
-                    break
-                recv_reqs.append(recv_req)
-
-            while True:
-                try:
-                    recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
-                except zmq.ZMQError:
-                    break
-                recv_reqs.append(recv_rpc)
+        if self.pp_rank == 0:
+            if self.attn_tp_rank == 0:
+                recv_reqs = []
+
+                while True:
+                    try:
+                        recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                    except zmq.ZMQError:
+                        break
+                    recv_reqs.append(recv_req)
+
+                while True:
+                    try:
+                        recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
+                    except zmq.ZMQError:
+                        break
+                    recv_reqs.append(recv_rpc)
+            else:
+                recv_reqs = None
         else:
-            recv_reqs = None
+            if self.attn_tp_rank == 0:
+                dp_offset = self.dp_rank * self.attn_tp_size
+                recv_reqs = point_to_point_pyobj(
+                    [],
+                    self.pp_rank * self.tp_size + dp_offset,
+                    self.world_group.cpu_group,
+                    (self.pp_rank - 1) * self.tp_size + dp_offset,
+                    self.pp_rank * self.tp_size + dp_offset,
+                )
+            else:
+                recv_reqs = None
 
         if self.server_args.enable_dp_attention:
             if self.attn_tp_rank == 0:
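Note: event_loop_pp keeps pp_size microbatches in flight and visits them round-robin; each iteration runs one microbatch forward while post-processing the outputs of the next slot, which were produced a full rotation earlier. A standalone sketch of just the visiting order (illustrative only, not the scheduler itself):

    # Hypothetical trace of the rotation for pp_size = 2.
    pp_size = 2
    for step in range(6):
        mb_id = step % pp_size
        next_mb_id = (mb_id + 1) % pp_size
        print(f"step {step}: run microbatch {mb_id}, "
              f"then post-process microbatch {next_mb_id} if it is in flight")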
@@ -715,20 +851,27 @@ class Scheduler(
                 control_reqs = None
 
             if self.attn_tp_size != 1:
-                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
                 work_reqs = broadcast_pyobj(
                     work_reqs,
-                    self.attn_tp_rank,
+                    self.attn_tp_group.rank,
                     self.attn_tp_cpu_group,
-                    src=attn_tp_rank_0,
+                    src=self.attn_tp_group.ranks[0],
                 )
             if self.tp_size != 1:
                 control_reqs = broadcast_pyobj(
-                    control_reqs, self.tp_rank, self.tp_cpu_group
+                    control_reqs,
+                    self.tp_group.rank,
+                    self.tp_cpu_group,
+                    src=self.tp_group.ranks[0],
                 )
             recv_reqs = work_reqs + control_reqs
         elif self.tp_size != 1:
-            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
+            recv_reqs = broadcast_pyobj(
+                recv_reqs,
+                self.tp_group.rank,
+                self.tp_cpu_group,
+                src=self.tp_group.ranks[0],
+            )
         return recv_reqs
 
     def process_input_requests(self, recv_reqs: List):
@@ -777,6 +920,10 @@ class Scheduler(
             )
             custom_logit_processor = None
 
+        if recv_req.bootstrap_port is None:
+            # Use default bootstrap port
+            recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
+
         req = Req(
             recv_req.rid,
             recv_req.input_text,
@@ -900,6 +1047,7 @@ class Scheduler(
             add_to_grammar_queue = True
 
         if add_to_grammar_queue:
+            req.queue_time_start = time.time()
             self.grammar_queue.append(req)
         else:
             self._add_request_to_queue(req)
@@ -1025,12 +1173,14 @@ class Scheduler(
 
         self.metrics_collector.log_stats(self.stats)
 
-    def log_decode_stats(self):
+    def log_decode_stats(self, running_batch=None):
+        batch = running_batch or self.running_batch
+
         gap_latency = time.time() - self.last_decode_stats_tic
         self.last_decode_stats_tic = time.time()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
-        num_running_reqs = len(self.running_batch.reqs)
+        num_running_reqs = len(batch.reqs)
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -1130,19 +1280,25 @@ class Scheduler(
 
     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
+        chunked_req_to_exclude = set()
+        if self.chunked_req:
+            # Move the chunked request out of the batch so that we can merge
+            # only finished requests to running_batch.
+            chunked_req_to_exclude.add(self.chunked_req)
+            self.tree_cache.cache_unfinished_req(self.chunked_req)
+            # chunked request keeps its rid but will get a new req_pool_idx
+            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
-            if self.chunked_req:
-                # Move the chunked request out of the batch so that we can merge
-                # only finished requests to running_batch.
-                self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
-                self.tree_cache.cache_unfinished_req(self.chunked_req)
-                # chunked request keeps its rid but will get a new req_pool_idx
-                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
-                self.running_batch.batch_is_full = False
+            if self.last_batch.chunked_req is not None:
+                # Under pipeline parallelism, after the last chunk, the current
+                # microbatch still tracks an outdated chunked_req. We need to discard it.
+                chunked_req_to_exclude.add(self.last_batch.chunked_req)
 
             # Filter batch
             last_bs = self.last_batch.batch_size()
-            self.last_batch.filter_batch()
+            self.last_batch.filter_batch(
+                chunked_req_to_exclude=list(chunked_req_to_exclude)
+            )
             if self.last_batch.batch_size() < last_bs:
                 self.running_batch.batch_is_full = False
 
@@ -1172,6 +1328,12 @@ class Scheduler(
 
         return ret
 
+    def get_num_allocatable_reqs(self, running_bs):
+        res = global_server_args_dict["max_micro_batch_size"] - running_bs
+        if self.pp_size > 1:
+            res = min(res, self.req_to_token_pool.available_size())
+        return res
+
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
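Note: get_num_allocatable_reqs bounds admission by the per-microbatch budget instead of comparing against max_running_requests directly. A worked example, assuming max_micro_batch_size = 32, 30 requests already running, and 5 free slots in req_to_token_pool:

    running_bs = 30
    res = 32 - running_bs   # per-microbatch headroom: 2
    # With pp_size > 1, the result is additionally capped by the free
    # req_to_token_pool slots, since a request's slot can stay occupied
    # across microbatches.
    res = min(res, 5)
    assert res == 2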
@@ -1184,7 +1346,12 @@ class Scheduler(
             return None
 
         running_bs = len(self.running_batch.reqs)
-        if running_bs >= self.max_running_requests:
+        # Ignore the check if self.chunked_req is not None.
+        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
+        # as the space for the chunked request has just been released.
+        # In the PP case, a chunked req can start in one microbatch and end in another, so the max_running_requests per microbatch should not be strict.
+        # Instead, we should always allow a chunked request to be added; otherwise, there will be a memory leak.
+        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
             self.running_batch.batch_is_full = True
             return None
 
@@ -1228,7 +1395,7 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break
 
-            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
+            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                 self.running_batch.batch_is_full = True
                 break
 
@@ -1240,16 +1407,14 @@ class Scheduler(
             res = adder.add_one_req(
                 req, self.chunked_req, self.enable_hierarchical_cache
             )
+
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
                     if self.enable_hierarchical_cache:
                         # Set batch_is_full after making sure there are requests that can be served
                         self.running_batch.batch_is_full = len(
                             adder.can_run_list
-                        ) > 0 or (
-                            self.running_batch is not None
-                            and not self.running_batch.is_empty()
-                        )
+                        ) > 0 or (not self.running_batch.is_empty())
                     else:
                         self.running_batch.batch_is_full = True
                 break
@@ -1292,6 +1457,7 @@ class Scheduler(
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
+            chunked_req=self.chunked_req,
         )
         new_batch.prepare_for_extend()
 
@@ -1365,13 +1531,22 @@ class Scheduler(
         ):
             self.stop_profile()
 
+        if self.forward_sleep_time is not None:
+            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
+            time.sleep(self.forward_sleep_time)
+
         # Run forward
         if self.is_generation:
             if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
-                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    model_worker_batch
-                )
+                if self.pp_group.is_last_rank:
+                    logits_output, next_token_ids = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
+                else:
+                    pp_hidden_states_proxy_tensors, _ = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
                 bid = model_worker_batch.bid
             else:
                 (
@@ -1385,7 +1560,9 @@ class Scheduler(
                 )
                 self.spec_num_total_forward_ct += batch.batch_size()
                 self.num_generated_tokens += num_accepted_tokens
-            batch.output_ids = next_token_ids
+
+            if self.pp_group.is_last_rank:
+                batch.output_ids = next_token_ids
 
             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
@@ -1400,8 +1577,13 @@ class Scheduler(
                 extend_logprob_start_len_per_req = None
 
             ret = GenerationBatchResult(
-                logits_output=logits_output,
-                next_token_ids=next_token_ids,
+                logits_output=logits_output if self.pp_group.is_last_rank else None,
+                pp_hidden_states_proxy_tensors=(
+                    pp_hidden_states_proxy_tensors
+                    if not self.pp_group.is_last_rank
+                    else None
+                ),
+                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                 bid=bid,
@@ -1552,6 +1734,7 @@ class Scheduler(
 
     def move_ready_grammar_requests(self):
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
+
         num_ready_reqs = 0
         for req in self.grammar_queue:
             try:
@@ -1618,7 +1801,11 @@ class Scheduler(
 
     def flush_cache(self):
         """Flush the memory pool and cache."""
-        if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
+        if (
+            len(self.waiting_queue) == 0
+            and self.running_batch.is_empty()
+            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
+        ):
             self.cur_batch = None
             self.last_batch = None
             self.tree_cache.reset()
@@ -1656,7 +1843,6 @@ class Scheduler(
             ret["avg_spec_accept_length"] = (
                 self.cum_spec_accept_length / self.cum_spec_accept_count
             )
-
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict
         return GetInternalStateReqOutput(
@@ -1667,6 +1853,7 @@ class Scheduler(
         server_args_dict = recv_req.server_args
         args_allow_update = set(
             [
+                "max_micro_batch_size",
                 "speculative_accept_threshold_single",
                 "speculative_accept_threshold_acc",
             ]
@@ -1677,6 +1864,14 @@ class Scheduler(
                 logging.warning(f"Updating {k} is not supported.")
                 if_success = False
                 break
+            elif k == "max_micro_batch_size" and (
+                v > self.max_running_requests // self.pp_size or v < 1
+            ):
+                logging.warning(
+                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
+                )
+                if_success = False
+                break
         if if_success:
             if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                 avg_spec_accept_length = (
@@ -1815,6 +2010,13 @@ class Scheduler(
             del self.stashed_model_static_state
         return ResumeMemoryOccupationReqOutput()
 
+    def slow_down(self, recv_req: SlowDownReqInput):
+        t = recv_req.forward_sleep_time
+        if t is not None and t <= 0:
+            t = None
+        self.forward_sleep_time = t
+        return SlowDownReqOutput()
+
     def profile(self, recv_req: ProfileReq):
         if recv_req.type == ProfileReqType.START_PROFILE:
             return self.start_profile(
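Note: the slow_down handler pairs with the SlowDownReqInput/SlowDownReqOutput structs added to io_struct in this release and lets a client throttle the scheduler's forward passes. A sketch of the semantics (calling the handler directly for illustration; in the server the request arrives through the dispatcher registered above):

    # Sleep 0.5 s before every run_batch forward pass.
    scheduler.slow_down(SlowDownReqInput(forward_sleep_time=0.5))
    # Clear the throttle again; values <= 0 are normalized to None.
    scheduler.slow_down(SlowDownReqInput(forward_sleep_time=0))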
@@ -1958,6 +2160,16 @@ class Scheduler(
         else:
             del self.sessions[session_id]
 
+    def get_print_prefix(self):
+        prefix = ""
+        if self.dp_rank is not None:
+            prefix += f" DP{self.dp_rank}"
+        if self.server_args.tp_size > 1:
+            prefix += f" TP{self.tp_rank}"
+        if self.pp_size > 1:
+            prefix += f" PP{self.pp_rank}"
+        return prefix
+
 
 def is_health_check_generate_req(recv_req):
     return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
@@ -1982,14 +2194,18 @@ def run_scheduler_process(
     port_args: PortArgs,
     gpu_id: int,
     tp_rank: int,
+    pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
 ):
     # Generate the prefix
-    if dp_rank is None:
-        prefix = f" TP{tp_rank}"
-    else:
-        prefix = f" DP{dp_rank} TP{tp_rank}"
+    prefix = ""
+    if dp_rank is not None:
+        prefix += f" DP{dp_rank}"
+    if server_args.tp_size > 1:
+        prefix += f" TP{tp_rank}"
+    if server_args.pp_size > 1:
+        prefix += f" PP{pp_rank}"
 
     # Config the process
     kill_itself_when_parent_died()
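Note: the rewritten prefix logic only tags the ranks that are actually sharded, so a plain single-GPU run gets an empty prefix. A small self-contained sketch of the same logic with two worked examples:

    def make_prefix(dp_rank, tp_size, tp_rank, pp_size, pp_rank):
        # Mirrors the log-prefix construction above.
        prefix = ""
        if dp_rank is not None:
            prefix += f" DP{dp_rank}"
        if tp_size > 1:
            prefix += f" TP{tp_rank}"
        if pp_size > 1:
            prefix += f" PP{pp_rank}"
        return prefix

    assert make_prefix(None, 1, 0, 2, 1) == " PP1"
    assert make_prefix(0, 2, 1, 1, 0) == " DP0 TP1"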
@@ -2011,7 +2227,7 @@ def run_scheduler_process(
 
     # Create a scheduler and run the event loop
     try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
+        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
         pipe_writer.send(
             {
                 "status": "ready",
@@ -2022,7 +2238,9 @@ def run_scheduler_process(
         disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
 
         if disaggregation_mode == DisaggregationMode.NULL:
-            if scheduler.enable_overlap:
+            if server_args.pp_size > 1:
+                scheduler.event_loop_pp()
+            elif scheduler.enable_overlap:
                 scheduler.event_loop_overlap()
             else:
                 scheduler.event_loop_normal()
@@ -2031,6 +2249,7 @@ def run_scheduler_process(
                 scheduler.event_loop_overlap_disagg_prefill()
             else:
                 scheduler.event_loop_normal_disagg_prefill()
+
         elif disaggregation_mode == DisaggregationMode.DECODE:
             if scheduler.enable_overlap:
                 scheduler.event_loop_overlap_disagg_decode()
sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -278,7 +278,7 @@ class SchedulerOutputProcessorMixin:
                 self.attn_tp_rank == 0
                 and self.forward_ct_decode % self.server_args.decode_log_interval == 0
             ):
-                self.log_decode_stats()
+                self.log_decode_stats(running_batch=batch)
 
     def add_input_logprob_return_values(
         self: Scheduler,