sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -0
- sglang/check_env.py +3 -3
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +15 -0
- sglang/srt/conversation.py +122 -1
- sglang/srt/disaggregation/decode.py +8 -2
- sglang/srt/disaggregation/fake/__init__.py +1 -0
- sglang/srt/disaggregation/fake/conn.py +88 -0
- sglang/srt/disaggregation/prefill.py +12 -3
- sglang/srt/disaggregation/utils.py +16 -2
- sglang/srt/entrypoints/engine.py +52 -21
- sglang/srt/entrypoints/http_server.py +27 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
- sglang/srt/layers/attention/flashinfer_backend.py +107 -82
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +1 -1
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
- sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +1 -1
- sglang/srt/layers/quantization/fp8.py +20 -22
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +84 -35
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +34 -15
- sglang/srt/managers/scheduler.py +273 -67
- sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
- sglang/srt/managers/tp_worker.py +52 -17
- sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +123 -58
- sglang/srt/models/deepseek_nextn.py +1 -257
- sglang/srt/models/deepseek_v2.py +78 -18
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +92 -30
- sglang/srt/models/llama4.py +2 -1
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +0 -12
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/openai_api/adapter.py +49 -8
- sglang/srt/openai_api/protocol.py +13 -1
- sglang/srt/reasoning_parser.py +25 -1
- sglang/srt/server_args.py +83 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +91 -9
- sglang/test/runners.py +4 -0
- sglang/test/send_one.py +84 -28
- sglang/test/test_utils.py +67 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
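
The headline change in this release is pipeline-parallel scheduling: `sglang/srt/managers/scheduler.py` gains `pp_rank`/`pp_size` plumbing, a new `event_loop_pp()` loop, and cross-stage transfers via `PPProxyTensors` and `point_to_point_pyobj` (full diff below). As a rough sketch of how the new degree might be enabled, assuming the `pp_size` field added to `ServerArgs` in this release is accepted by `sglang.Engine` the same way `tp_size` is (unverified):

```python
# Hedged sketch only: pp_size comes from the ServerArgs change in this diff;
# whether Engine forwards it exactly like this is an assumption, not documented behavior.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # any supported model path
        tp_size=2,  # tensor-parallel degree (existing option)
        pp_size=2,  # pipeline-parallel degree (new relative to 0.4.6)
    )
    out = llm.generate("The capital of France is", {"max_new_tokens": 8})
    print(out["text"])
    llm.shutdown()
```

With `pp_size > 1`, `run_scheduler_process` routes into the new `scheduler.event_loop_pp()` instead of the overlap/normal loops, as shown at the end of the scheduler diff.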
sglang/srt/managers/scheduler.py
CHANGED
@@ -51,6 +51,7 @@ from sglang.srt.disaggregation.utils import (
     ReqToMetadataIdxAllocator,
     TransferBackend,
 )
+from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -114,7 +115,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -127,6 +132,7 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_zmq_socket,
     kill_itself_when_parent_died,
+    point_to_point_pyobj,
     pyspy_dump_schedulers,
     set_gpu_proc_affinity,
     set_random_seed,
@@ -145,8 +151,9 @@ RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
 
 @dataclass
 class GenerationBatchResult:
-    logits_output: LogitsProcessorOutput
-
+    logits_output: Optional[LogitsProcessorOutput]
+    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
+    next_token_ids: Optional[List[int]]
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
     bid: int
@@ -171,12 +178,16 @@ class Scheduler(
         port_args: PortArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
     ):
         # Parse args
         self.server_args = server_args
         self.tp_rank = tp_rank
+        self.pp_rank = pp_rank
         self.tp_size = server_args.tp_size
+        self.pp_size = server_args.pp_size
+        self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
@@ -192,7 +203,6 @@ class Scheduler(
         self.page_size = server_args.page_size
 
         # Distributed rank info
-        self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
@@ -204,7 +214,7 @@ class Scheduler(
 
         # Init inter-process communication
         context = zmq.Context(2)
-        if self.attn_tp_rank == 0:
+        if self.pp_rank == 0 and self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
             )
@@ -248,9 +258,6 @@ class Scheduler(
         if not self.is_generation:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for embedding models.")
-        if self.model_config.is_multimodal:
-            self.enable_overlap = False
-            logger.info("Overlap scheduler is disabled for multimodal models.")
 
         # Launch a tensor parallel worker
         if self.enable_overlap:
@@ -262,6 +269,7 @@ class Scheduler(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
+            pp_rank=pp_rank,
             dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )
@@ -295,8 +303,18 @@ class Scheduler(
             _,
             _,
         ) = self.tp_worker.get_worker_info()
-
+        if global_server_args_dict["max_micro_batch_size"] is None:
+            global_server_args_dict["max_micro_batch_size"] = max(
+                self.max_running_requests // server_args.pp_size, 1
+            )
+
+        self.tp_group = self.tp_worker.get_tp_group()
+        self.tp_cpu_group = self.tp_group.cpu_group
+        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
         self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
+        self.pp_group = get_pp_group()
+        self.world_group = get_world_group()
+
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)
@@ -645,6 +663,7 @@ class Scheduler(
             self.cur_batch = batch
 
             if batch:
+                batch.launch_done = threading.Event()
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))
 
@@ -656,7 +675,7 @@ class Scheduler(
                         forward_mode=ForwardMode.DUMMY_FIRST,
                         next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                     )
-                    self.process_batch_result(tmp_batch, None)
+                    self.process_batch_result(tmp_batch, None, batch.launch_done)
 
             if self.last_batch:
                 # Process the results of the last batch
@@ -664,7 +683,10 @@ class Scheduler(
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
-
+                # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
+                self.process_batch_result(
+                    tmp_batch, tmp_result, batch.launch_done if batch else None
+                )
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
                 self.check_memory()
@@ -672,26 +694,141 @@ class Scheduler(
 
             self.last_batch = batch
 
+    @DynamicGradMode()
+    def event_loop_pp(self):
+        """A non-overlap scheduler loop for pipeline parallelism."""
+        mbs = [None] * self.pp_size
+        last_mbs = [None] * self.pp_size
+        self.running_mbs = [
+            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
+        ]
+        bids = [None] * self.pp_size
+        pp_outputs: Optional[PPProxyTensors] = None
+        while True:
+            server_is_idle = True
+            for mb_id in range(self.pp_size):
+                self.running_batch = self.running_mbs[mb_id]
+                self.last_batch = last_mbs[mb_id]
+
+                recv_reqs = self.recv_requests()
+                self.process_input_requests(recv_reqs)
+                mbs[mb_id] = self.get_next_batch_to_run()
+                self.running_mbs[mb_id] = self.running_batch
+
+                self.cur_batch = mbs[mb_id]
+                if self.cur_batch:
+                    server_is_idle = False
+                    result = self.run_batch(self.cur_batch)
+
+                # send the outputs to the next step
+                if self.pp_group.is_last_rank:
+                    if self.cur_batch:
+                        next_token_ids, bids[mb_id] = (
+                            result.next_token_ids,
+                            result.bid,
+                        )
+                        pp_outputs = PPProxyTensors(
+                            {
+                                "next_token_ids": next_token_ids,
+                            }
+                        )
+                        # send the output from the last round to let the next stage worker run post processing
+                        self.pp_group.send_tensor_dict(
+                            pp_outputs.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                # receive outputs and post-process (filter finished reqs) the coming microbatch
+                next_mb_id = (mb_id + 1) % self.pp_size
+                next_pp_outputs = None
+                if mbs[next_mb_id] is not None:
+                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
+                        self.pp_group.recv_tensor_dict(
+                            all_gather_group=self.attn_tp_group
+                        )
+                    )
+                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
+                    output_result = GenerationBatchResult(
+                        logits_output=None,
+                        pp_hidden_states_proxy_tensors=None,
+                        next_token_ids=next_pp_outputs["next_token_ids"],
+                        extend_input_len_per_req=None,
+                        extend_logprob_start_len_per_req=None,
+                        bid=bids[next_mb_id],
+                    )
+                    self.process_batch_result(mbs[next_mb_id], output_result)
+                    last_mbs[next_mb_id] = mbs[next_mb_id]
+
+                # carry the outputs to the next stage
+                if not self.pp_group.is_last_rank:
+                    if self.cur_batch:
+                        bids[mb_id] = result.bid
+                    if pp_outputs:
+                        # send the outputs from the last round to let the next stage worker run post processing
+                        self.pp_group.send_tensor_dict(
+                            pp_outputs.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                if not self.pp_group.is_last_rank:
+                    # send out reqs to the next stage
+                    dp_offset = self.dp_rank * self.attn_tp_size
+                    if self.attn_tp_rank == 0:
+                        point_to_point_pyobj(
+                            recv_reqs,
+                            self.pp_rank * self.tp_size + dp_offset,
+                            self.world_group.cpu_group,
+                            self.pp_rank * self.tp_size + dp_offset,
+                            (self.pp_rank + 1) * self.tp_size + dp_offset,
+                        )
+
+                    # send out proxy tensors to the next stage
+                    if self.cur_batch:
+                        self.pp_group.send_tensor_dict(
+                            result.pp_hidden_states_proxy_tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                pp_outputs = next_pp_outputs
+
+            # When the server is idle, self-check and re-init some states
+            if server_is_idle:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
-        if self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.pp_rank == 0:
+            if self.attn_tp_rank == 0:
+                recv_reqs = []
+
+                while True:
+                    try:
+                        recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                    except zmq.ZMQError:
+                        break
+                    recv_reqs.append(recv_req)
+
+                while True:
+                    try:
+                        recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
+                    except zmq.ZMQError:
+                        break
+                    recv_reqs.append(recv_rpc)
+            else:
+                recv_reqs = None
         else:
-
+            if self.attn_tp_rank == 0:
+                dp_offset = self.dp_rank * self.attn_tp_size
+                recv_reqs = point_to_point_pyobj(
+                    [],
+                    self.pp_rank * self.tp_size + dp_offset,
+                    self.world_group.cpu_group,
+                    (self.pp_rank - 1) * self.tp_size + dp_offset,
+                    self.pp_rank * self.tp_size + dp_offset,
+                )
+            else:
+                recv_reqs = None
 
         if self.server_args.enable_dp_attention:
             if self.attn_tp_rank == 0:
@@ -714,20 +851,27 @@ class Scheduler(
                 control_reqs = None
 
             if self.attn_tp_size != 1:
-                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
                 work_reqs = broadcast_pyobj(
                     work_reqs,
-                    self.
+                    self.attn_tp_group.rank,
                     self.attn_tp_cpu_group,
-                    src=
+                    src=self.attn_tp_group.ranks[0],
                 )
             if self.tp_size != 1:
                 control_reqs = broadcast_pyobj(
-                    control_reqs,
+                    control_reqs,
+                    self.tp_group.rank,
+                    self.tp_cpu_group,
+                    src=self.tp_group.ranks[0],
                 )
             recv_reqs = work_reqs + control_reqs
         elif self.tp_size != 1:
-            recv_reqs = broadcast_pyobj(
+            recv_reqs = broadcast_pyobj(
+                recv_reqs,
+                self.tp_group.rank,
+                self.tp_cpu_group,
+                src=self.tp_group.ranks[0],
+            )
         return recv_reqs
 
     def process_input_requests(self, recv_reqs: List):
@@ -899,6 +1043,7 @@ class Scheduler(
                 add_to_grammar_queue = True
 
         if add_to_grammar_queue:
+            req.queue_time_start = time.time()
             self.grammar_queue.append(req)
         else:
             self._add_request_to_queue(req)
@@ -1024,12 +1169,14 @@ class Scheduler(
 
             self.metrics_collector.log_stats(self.stats)
 
-    def log_decode_stats(self):
+    def log_decode_stats(self, running_batch=None):
+        batch = running_batch or self.running_batch
+
         gap_latency = time.time() - self.last_decode_stats_tic
         self.last_decode_stats_tic = time.time()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
-        num_running_reqs = len(
+        num_running_reqs = len(batch.reqs)
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -1129,19 +1276,25 @@ class Scheduler(
 
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
+        chunked_req_to_exclude = set()
+        if self.chunked_req:
+            # Move the chunked request out of the batch so that we can merge
+            # only finished requests to running_batch.
+            chunked_req_to_exclude.add(self.chunked_req)
+            self.tree_cache.cache_unfinished_req(self.chunked_req)
+            # chunked request keeps its rid but will get a new req_pool_idx
+            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
-            if self.chunked_req:
-                #
-                #
-                self.last_batch.
-                self.tree_cache.cache_unfinished_req(self.chunked_req)
-                # chunked request keeps its rid but will get a new req_pool_idx
-                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
-                self.running_batch.batch_is_full = False
+            if self.last_batch.chunked_req is not None:
+                # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
+                # We need to discard it.
+                chunked_req_to_exclude.add(self.last_batch.chunked_req)
 
             # Filter batch
             last_bs = self.last_batch.batch_size()
-            self.last_batch.filter_batch(
+            self.last_batch.filter_batch(
+                chunked_req_to_exclude=list(chunked_req_to_exclude)
+            )
             if self.last_batch.batch_size() < last_bs:
                 self.running_batch.batch_is_full = False
 
@@ -1171,6 +1324,12 @@ class Scheduler(
 
         return ret
 
+    def get_num_allocatable_reqs(self, running_bs):
+        res = global_server_args_dict["max_micro_batch_size"] - running_bs
+        if self.pp_size > 1:
+            res = min(res, self.req_to_token_pool.available_size())
+        return res
+
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue
         if self.grammar_queue:
@@ -1183,7 +1342,12 @@ class Scheduler(
             return None
 
         running_bs = len(self.running_batch.reqs)
-
+        # Igore the check if self.chunked_req is not None.
+        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
+        # as the space for the chunked request has just been released.
+        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
+        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
+        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
             self.running_batch.batch_is_full = True
             return None
 
@@ -1227,7 +1391,7 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break
 
-            if
+            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                 self.running_batch.batch_is_full = True
                 break
 
@@ -1239,16 +1403,14 @@ class Scheduler(
             res = adder.add_one_req(
                 req, self.chunked_req, self.enable_hierarchical_cache
             )
+
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
                     if self.enable_hierarchical_cache:
                         # Set batch_is_full after making sure there are requests that can be served
                         self.running_batch.batch_is_full = len(
                             adder.can_run_list
-                        ) > 0 or (
-                            self.running_batch is not None
-                            and not self.running_batch.is_empty()
-                        )
+                        ) > 0 or (not self.running_batch.is_empty())
                     else:
                         self.running_batch.batch_is_full = True
                 break
@@ -1291,6 +1453,7 @@ class Scheduler(
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
+            chunked_req=self.chunked_req,
         )
         new_batch.prepare_for_extend()
 
@@ -1368,9 +1531,14 @@ class Scheduler(
         if self.is_generation:
            if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
-
-
-
+                if self.pp_group.is_last_rank:
+                    logits_output, next_token_ids = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
+                else:
+                    pp_hidden_states_proxy_tensors, _ = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
                 bid = model_worker_batch.bid
             else:
                 (
@@ -1384,7 +1552,9 @@ class Scheduler(
                 )
                 self.spec_num_total_forward_ct += batch.batch_size()
                 self.num_generated_tokens += num_accepted_tokens
-
+
+            if self.pp_group.is_last_rank:
+                batch.output_ids = next_token_ids
 
             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
@@ -1399,8 +1569,13 @@ class Scheduler(
                 extend_logprob_start_len_per_req = None
 
             ret = GenerationBatchResult(
-                logits_output=logits_output,
-
+                logits_output=logits_output if self.pp_group.is_last_rank else None,
+                pp_hidden_states_proxy_tensors=(
+                    pp_hidden_states_proxy_tensors
+                    if not self.pp_group.is_last_rank
+                    else None
+                ),
+                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                 bid=bid,
@@ -1417,14 +1592,15 @@ class Scheduler(
         self,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+        launch_done: Optional[threading.Event] = None,
     ):
         if batch.forward_mode.is_decode():
-            self.process_batch_result_decode(batch, result)
+            self.process_batch_result_decode(batch, result, launch_done)
         elif batch.forward_mode.is_extend():
-            self.process_batch_result_prefill(batch, result)
+            self.process_batch_result_prefill(batch, result, launch_done)
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-                self.tp_worker.
+                self.tp_worker.resolve_last_batch_result(launch_done)
                 if batch.next_batch_sampling_info:
                     batch.next_batch_sampling_info.update_regex_vocab_mask()
                     self.current_stream.synchronize()
@@ -1550,6 +1726,7 @@ class Scheduler(
 
     def move_ready_grammar_requests(self):
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
+
         num_ready_reqs = 0
         for req in self.grammar_queue:
             try:
@@ -1616,7 +1793,11 @@ class Scheduler(
 
     def flush_cache(self):
         """Flush the memory pool and cache."""
-        if
+        if (
+            len(self.waiting_queue) == 0
+            and self.running_batch.is_empty()
+            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
+        ):
            self.cur_batch = None
            self.last_batch = None
            self.tree_cache.reset()
@@ -1654,7 +1835,6 @@ class Scheduler(
             ret["avg_spec_accept_length"] = (
                 self.cum_spec_accept_length / self.cum_spec_accept_count
             )
-
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict
         return GetInternalStateReqOutput(
@@ -1665,6 +1845,7 @@ class Scheduler(
         server_args_dict = recv_req.server_args
         args_allow_update = set(
             [
+                "max_micro_batch_size",
                 "speculative_accept_threshold_single",
                 "speculative_accept_threshold_acc",
             ]
@@ -1675,6 +1856,14 @@ class Scheduler(
                 logging.warning(f"Updating {k} is not supported.")
                 if_success = False
                 break
+            elif k == "max_micro_batch_size" and (
+                v > self.max_running_requests // self.pp_size or v < 1
+            ):
+                logging.warning(
+                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
+                )
+                if_success = False
+                break
         if if_success:
             if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                 avg_spec_accept_length = (
@@ -1956,6 +2145,16 @@ class Scheduler(
         else:
             del self.sessions[session_id]
 
+    def get_print_prefix(self):
+        prefix = ""
+        if self.dp_rank is not None:
+            prefix += f" DP{self.dp_rank}"
+        if self.server_args.tp_size > 1:
+            prefix += f" TP{self.tp_rank}"
+        if self.pp_size > 1:
+            prefix += f" PP{self.pp_rank}"
+        return prefix
+
 
 def is_health_check_generate_req(recv_req):
     return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
@@ -1980,14 +2179,18 @@ def run_scheduler_process(
     port_args: PortArgs,
     gpu_id: int,
     tp_rank: int,
+    pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
 ):
     # Generate the prefix
-
-
-
-
+    prefix = ""
+    if dp_rank is not None:
+        prefix += f" DP{dp_rank}"
+    if server_args.tp_size > 1:
+        prefix += f" TP{tp_rank}"
+    if server_args.pp_size > 1:
+        prefix += f" PP{pp_rank}"
 
     # Config the process
     kill_itself_when_parent_died()
@@ -2009,7 +2212,7 @@ def run_scheduler_process(
 
     # Create a scheduler and run the event loop
    try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
+        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
@@ -2020,7 +2223,9 @@ def run_scheduler_process(
        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
 
        if disaggregation_mode == DisaggregationMode.NULL:
-            if
+            if server_args.pp_size > 1:
+                scheduler.event_loop_pp()
+            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
@@ -2029,6 +2234,7 @@ def run_scheduler_process(
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                scheduler.event_loop_normal_disagg_prefill()
+
        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()