sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -0
- sglang/check_env.py +3 -3
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +15 -0
- sglang/srt/conversation.py +122 -1
- sglang/srt/entrypoints/engine.py +44 -22
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +107 -82
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -6
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +1 -1
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +84 -35
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +25 -15
- sglang/srt/managers/scheduler.py +263 -59
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
- sglang/srt/managers/tp_worker.py +51 -16
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +115 -57
- sglang/srt/models/deepseek_nextn.py +1 -257
- sglang/srt/models/deepseek_v2.py +78 -18
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +92 -30
- sglang/srt/models/llama4.py +2 -1
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +0 -12
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/openai_api/adapter.py +34 -22
- sglang/srt/openai_api/protocol.py +11 -1
- sglang/srt/server_args.py +67 -22
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +88 -9
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +29 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +61 -51
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -51,6 +51,7 @@ from sglang.srt.disaggregation.utils import (
     ReqToMetadataIdxAllocator,
     TransferBackend,
 )
+from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -114,7 +115,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -127,6 +132,7 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_zmq_socket,
     kill_itself_when_parent_died,
+    point_to_point_pyobj,
     pyspy_dump_schedulers,
     set_gpu_proc_affinity,
     set_random_seed,
@@ -145,8 +151,9 @@ RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
 
 @dataclass
 class GenerationBatchResult:
-    logits_output: LogitsProcessorOutput
-    next_token_ids: List[int]
+    logits_output: Optional[LogitsProcessorOutput]
+    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
+    next_token_ids: Optional[List[int]]
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
     bid: int
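With pipeline parallelism, only the last pipeline stage produces logits and sampled token ids; intermediate stages pass hidden-state proxy tensors to the next stage instead, which is why these fields become `Optional`. A minimal stand-in for that contract (illustrative only, not the sglang class; the helper method is hypothetical):

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class PPAwareResult:
    """Illustrative stand-in for the updated GenerationBatchResult contract."""

    logits_output: Optional[Any]  # populated only on the last PP rank
    pp_hidden_states_proxy_tensors: Optional[Any]  # populated only on earlier PP ranks
    next_token_ids: Optional[List[int]]  # populated only on the last PP rank

    def is_from_last_pp_rank(self) -> bool:
        # Hypothetical helper: last-rank results carry sampled tokens, while
        # intermediate-rank results only carry tensors for the next stage.
        return self.next_token_ids is not None
```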
@@ -171,12 +178,16 @@ class Scheduler(
         port_args: PortArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
     ):
         # Parse args
         self.server_args = server_args
         self.tp_rank = tp_rank
+        self.pp_rank = pp_rank
         self.tp_size = server_args.tp_size
+        self.pp_size = server_args.pp_size
+        self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
@@ -192,7 +203,6 @@ class Scheduler(
         self.page_size = server_args.page_size
 
         # Distributed rank info
-        self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
@@ -204,7 +214,7 @@ class Scheduler(
 
         # Init inter-process communication
         context = zmq.Context(2)
-        if self.attn_tp_rank == 0:
+        if self.pp_rank == 0 and self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
             )
@@ -259,6 +269,7 @@ class Scheduler(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
+            pp_rank=pp_rank,
             dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )
@@ -292,8 +303,18 @@ class Scheduler(
             _,
             _,
         ) = self.tp_worker.get_worker_info()
-        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+        if global_server_args_dict["max_micro_batch_size"] is None:
+            global_server_args_dict["max_micro_batch_size"] = max(
+                self.max_running_requests // server_args.pp_size, 1
+            )
+
+        self.tp_group = self.tp_worker.get_tp_group()
+        self.tp_cpu_group = self.tp_group.cpu_group
+        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
         self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
+        self.pp_group = get_pp_group()
+        self.world_group = get_world_group()
+
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)
@@ -673,26 +694,141 @@ class Scheduler(
 
             self.last_batch = batch
 
+    @DynamicGradMode()
+    def event_loop_pp(self):
+        """A non-overlap scheduler loop for pipeline parallelism."""
+        mbs = [None] * self.pp_size
+        last_mbs = [None] * self.pp_size
+        self.running_mbs = [
+            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
+        ]
+        bids = [None] * self.pp_size
+        pp_outputs: Optional[PPProxyTensors] = None
+        while True:
+            server_is_idle = True
+            for mb_id in range(self.pp_size):
+                self.running_batch = self.running_mbs[mb_id]
+                self.last_batch = last_mbs[mb_id]
+
+                recv_reqs = self.recv_requests()
+                self.process_input_requests(recv_reqs)
+                mbs[mb_id] = self.get_next_batch_to_run()
+                self.running_mbs[mb_id] = self.running_batch
+
+                self.cur_batch = mbs[mb_id]
+                if self.cur_batch:
+                    server_is_idle = False
+                    result = self.run_batch(self.cur_batch)
+
+                # send the outputs to the next step
+                if self.pp_group.is_last_rank:
+                    if self.cur_batch:
+                        next_token_ids, bids[mb_id] = (
+                            result.next_token_ids,
+                            result.bid,
+                        )
+                        pp_outputs = PPProxyTensors(
+                            {
+                                "next_token_ids": next_token_ids,
+                            }
+                        )
+                        # send the output from the last round to let the next stage worker run post processing
+                        self.pp_group.send_tensor_dict(
+                            pp_outputs.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                # receive outputs and post-process (filter finished reqs) the coming microbatch
+                next_mb_id = (mb_id + 1) % self.pp_size
+                next_pp_outputs = None
+                if mbs[next_mb_id] is not None:
+                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
+                        self.pp_group.recv_tensor_dict(
+                            all_gather_group=self.attn_tp_group
+                        )
+                    )
+                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
+                    output_result = GenerationBatchResult(
+                        logits_output=None,
+                        pp_hidden_states_proxy_tensors=None,
+                        next_token_ids=next_pp_outputs["next_token_ids"],
+                        extend_input_len_per_req=None,
+                        extend_logprob_start_len_per_req=None,
+                        bid=bids[next_mb_id],
+                    )
+                    self.process_batch_result(mbs[next_mb_id], output_result)
+                    last_mbs[next_mb_id] = mbs[next_mb_id]
+
+                # carry the outputs to the next stage
+                if not self.pp_group.is_last_rank:
+                    if self.cur_batch:
+                        bids[mb_id] = result.bid
+                    if pp_outputs:
+                        # send the outputs from the last round to let the next stage worker run post processing
+                        self.pp_group.send_tensor_dict(
+                            pp_outputs.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                if not self.pp_group.is_last_rank:
+                    # send out reqs to the next stage
+                    dp_offset = self.dp_rank * self.attn_tp_size
+                    if self.attn_tp_rank == 0:
+                        point_to_point_pyobj(
+                            recv_reqs,
+                            self.pp_rank * self.tp_size + dp_offset,
+                            self.world_group.cpu_group,
+                            self.pp_rank * self.tp_size + dp_offset,
+                            (self.pp_rank + 1) * self.tp_size + dp_offset,
+                        )
+
+                    # send out proxy tensors to the next stage
+                    if self.cur_batch:
+                        self.pp_group.send_tensor_dict(
+                            result.pp_hidden_states_proxy_tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                pp_outputs = next_pp_outputs
+
+            # When the server is idle, self-check and re-init some states
+            if server_is_idle:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
-        if self.attn_tp_rank == 0:
-            recv_reqs = []
-
-            while True:
-                try:
-                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-                except zmq.ZMQError:
-                    break
-                recv_reqs.append(recv_req)
-
-            while True:
-                try:
-                    recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
-                except zmq.ZMQError:
-                    break
-                recv_reqs.append(recv_rpc)
+        if self.pp_rank == 0:
+            if self.attn_tp_rank == 0:
+                recv_reqs = []
+
+                while True:
+                    try:
+                        recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                    except zmq.ZMQError:
+                        break
+                    recv_reqs.append(recv_req)
+
+                while True:
+                    try:
+                        recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
+                    except zmq.ZMQError:
+                        break
+                    recv_reqs.append(recv_rpc)
+            else:
+                recv_reqs = None
         else:
-            recv_reqs = None
+            if self.attn_tp_rank == 0:
+                dp_offset = self.dp_rank * self.attn_tp_size
+                recv_reqs = point_to_point_pyobj(
+                    [],
+                    self.pp_rank * self.tp_size + dp_offset,
+                    self.world_group.cpu_group,
+                    (self.pp_rank - 1) * self.tp_size + dp_offset,
+                    self.pp_rank * self.tp_size + dp_offset,
+                )
+            else:
+                recv_reqs = None
 
         if self.server_args.enable_dp_attention:
             if self.attn_tp_rank == 0:
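The new `event_loop_pp` keeps `pp_size` microbatches in flight and rotates through the slots: after launching work for slot `mb_id`, it post-processes the neighbouring slot whose results from an earlier visit have had time to travel through the other stages. A toy, self-contained sketch of just that rotation, ignoring the real `PPProxyTensors` send/recv and the TP/DP groups:

```python
from typing import List, Optional


def microbatch_rotation_sketch(pp_size: int, rounds: int) -> None:
    """Toy model of the microbatch slot rotation used by event_loop_pp."""
    pending: List[Optional[str]] = [None] * pp_size  # one in-flight microbatch per slot

    for round_id in range(rounds):
        for mb_id in range(pp_size):
            # 1) launch new work for this slot (stands in for run_batch + send)
            pending[mb_id] = f"round {round_id}, microbatch {mb_id}"

            # 2) the next slot's work, launched on an earlier visit, is done by
            #    now, so post-process it (stands in for recv + process_batch_result)
            next_mb_id = (mb_id + 1) % pp_size
            if next_mb_id != mb_id and pending[next_mb_id] is not None:
                print("post-process:", pending[next_mb_id])
                pending[next_mb_id] = None


if __name__ == "__main__":
    microbatch_rotation_sketch(pp_size=2, rounds=2)
```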
@@ -715,20 +851,27 @@ class Scheduler(
                 control_reqs = None
 
             if self.attn_tp_size != 1:
-                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
                 work_reqs = broadcast_pyobj(
                     work_reqs,
-                    self.attn_tp_rank,
+                    self.attn_tp_group.rank,
                     self.attn_tp_cpu_group,
-                    src=attn_tp_rank_0,
+                    src=self.attn_tp_group.ranks[0],
                 )
             if self.tp_size != 1:
                 control_reqs = broadcast_pyobj(
-                    control_reqs, self.tp_rank, self.tp_cpu_group
+                    control_reqs,
+                    self.tp_group.rank,
+                    self.tp_cpu_group,
+                    src=self.tp_group.ranks[0],
                 )
             recv_reqs = work_reqs + control_reqs
         elif self.tp_size != 1:
-            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
+            recv_reqs = broadcast_pyobj(
+                recv_reqs,
+                self.tp_group.rank,
+                self.tp_cpu_group,
+                src=self.tp_group.ranks[0],
+            )
         return recv_reqs
 
     def process_input_requests(self, recv_reqs: List):
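Requests now hop from stage to stage instead of every rank pulling from the tokenizer socket: only `pp_rank == 0` reads ZMQ, and the `attn_tp_rank == 0` process of each stage relays the request list to the matching slot of the next stage with `point_to_point_pyobj`. The global-rank arithmetic used for those sends and receives (`pp_rank * tp_size + dp_offset`) can be checked standalone; this helper is illustrative, not sglang code:

```python
from typing import Optional, Tuple


def relay_peers(
    pp_rank: int, pp_size: int, tp_size: int, dp_rank: int, attn_tp_size: int
) -> Tuple[int, Optional[int], Optional[int]]:
    """Return (my_rank, prev_stage_rank, next_stage_rank) for the attn_tp_rank == 0
    process of a pipeline stage, following pp_rank * tp_size + dp_offset."""
    dp_offset = dp_rank * attn_tp_size
    me = pp_rank * tp_size + dp_offset
    prev_stage = (pp_rank - 1) * tp_size + dp_offset if pp_rank > 0 else None
    next_stage = (pp_rank + 1) * tp_size + dp_offset if pp_rank < pp_size - 1 else None
    return me, prev_stage, next_stage


# e.g. tp_size=4, pp_size=2, no DP: stage 0 sends to global rank 4, stage 1 receives from rank 0
print(relay_peers(pp_rank=0, pp_size=2, tp_size=4, dp_rank=0, attn_tp_size=4))  # (0, None, 4)
print(relay_peers(pp_rank=1, pp_size=2, tp_size=4, dp_rank=0, attn_tp_size=4))  # (4, 0, None)
```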
@@ -900,6 +1043,7 @@ class Scheduler(
                 add_to_grammar_queue = True
 
         if add_to_grammar_queue:
+            req.queue_time_start = time.time()
             self.grammar_queue.append(req)
         else:
             self._add_request_to_queue(req)
@@ -1025,12 +1169,14 @@ class Scheduler(
 
         self.metrics_collector.log_stats(self.stats)
 
-    def log_decode_stats(self):
+    def log_decode_stats(self, running_batch=None):
+        batch = running_batch or self.running_batch
+
         gap_latency = time.time() - self.last_decode_stats_tic
         self.last_decode_stats_tic = time.time()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
-        num_running_reqs = len(self.running_batch.reqs)
+        num_running_reqs = len(batch.reqs)
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -1130,19 +1276,25 @@ class Scheduler(
 
     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
+        chunked_req_to_exclude = set()
+        if self.chunked_req:
+            # Move the chunked request out of the batch so that we can merge
+            # only finished requests to running_batch.
+            chunked_req_to_exclude.add(self.chunked_req)
+            self.tree_cache.cache_unfinished_req(self.chunked_req)
+            # chunked request keeps its rid but will get a new req_pool_idx
+            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
-            if self.chunked_req:
-                # Move the chunked request out of the batch so that we can merge
-                # only finished requests to running_batch.
-                self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
-                self.tree_cache.cache_unfinished_req(self.chunked_req)
-                # chunked request keeps its rid but will get a new req_pool_idx
-                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
-                self.running_batch.batch_is_full = False
+            if self.last_batch.chunked_req is not None:
+                # In the context of pipeline parallelism, after the last chunk, the current microbatch still tracks the outdated chunked_req.
+                # We need to discard it.
+                chunked_req_to_exclude.add(self.last_batch.chunked_req)
 
             # Filter batch
             last_bs = self.last_batch.batch_size()
-            self.last_batch.filter_batch()
+            self.last_batch.filter_batch(
+                chunked_req_to_exclude=list(chunked_req_to_exclude)
+            )
             if self.last_batch.batch_size() < last_bs:
                 self.running_batch.batch_is_full = False
 
@@ -1172,6 +1324,12 @@ class Scheduler(
 
         return ret
 
+    def get_num_allocatable_reqs(self, running_bs):
+        res = global_server_args_dict["max_micro_batch_size"] - running_bs
+        if self.pp_size > 1:
+            res = min(res, self.req_to_token_pool.available_size())
+        return res
+
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
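`get_num_allocatable_reqs` caps each microbatch at `max_micro_batch_size`, which defaults to `max_running_requests // pp_size` (but at least 1), and under pipeline parallelism it is further limited by free `req_to_token_pool` slots. A worked example of that arithmetic in plain Python (standalone helpers, not scheduler code):

```python
def default_max_micro_batch_size(max_running_requests: int, pp_size: int) -> int:
    # mirrors: max(self.max_running_requests // server_args.pp_size, 1)
    return max(max_running_requests // pp_size, 1)


def num_allocatable_reqs(
    max_micro_batch_size: int, running_bs: int, pp_size: int, free_req_slots: int
) -> int:
    # mirrors get_num_allocatable_reqs: leftover room in the microbatch,
    # additionally bounded by free request-to-token pool slots when pp_size > 1
    res = max_micro_batch_size - running_bs
    if pp_size > 1:
        res = min(res, free_req_slots)
    return res


# e.g. max_running_requests=128, pp_size=2 -> 64 requests per microbatch;
# with 60 already running and only 10 free pool slots, 4 more can be admitted
mmbs = default_max_micro_batch_size(128, 2)  # 64
print(num_allocatable_reqs(mmbs, running_bs=60, pp_size=2, free_req_slots=10))  # 4
```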
@@ -1184,7 +1342,12 @@ class Scheduler(
             return None
 
         running_bs = len(self.running_batch.reqs)
-        if running_bs >= self.max_running_requests:
+        # Ignore the check if self.chunked_req is not None.
+        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
+        # as the space for the chunked request has just been released.
+        # In the PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
+        # Instead, we should always allow the chunked request to be added; otherwise, there will be a memory leak.
+        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
             self.running_batch.batch_is_full = True
             return None
 
@@ -1228,7 +1391,7 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break
 
-            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
+            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                 self.running_batch.batch_is_full = True
                 break
 
@@ -1240,16 +1403,14 @@ class Scheduler(
             res = adder.add_one_req(
                 req, self.chunked_req, self.enable_hierarchical_cache
             )
+
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
                     if self.enable_hierarchical_cache:
                         # Set batch_is_full after making sure there are requests that can be served
                         self.running_batch.batch_is_full = len(
                             adder.can_run_list
-                        ) > 0 or (
-                            self.running_batch is not None
-                            and not self.running_batch.is_empty()
-                        )
+                        ) > 0 or (not self.running_batch.is_empty())
                     else:
                         self.running_batch.batch_is_full = True
                     break
@@ -1292,6 +1453,7 @@ class Scheduler(
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
+            chunked_req=self.chunked_req,
         )
         new_batch.prepare_for_extend()
 
@@ -1369,9 +1531,14 @@ class Scheduler(
         if self.is_generation:
            if self.spec_algorithm.is_none():
                model_worker_batch = batch.get_model_worker_batch()
-                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    model_worker_batch
-                )
+                if self.pp_group.is_last_rank:
+                    logits_output, next_token_ids = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
+                else:
+                    pp_hidden_states_proxy_tensors, _ = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
                 bid = model_worker_batch.bid
             else:
                 (
@@ -1385,7 +1552,9 @@ class Scheduler(
                 )
                 self.spec_num_total_forward_ct += batch.batch_size()
                 self.num_generated_tokens += num_accepted_tokens
-            batch.output_ids = next_token_ids
+
+            if self.pp_group.is_last_rank:
+                batch.output_ids = next_token_ids
 
             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
@@ -1400,8 +1569,13 @@ class Scheduler(
                 extend_logprob_start_len_per_req = None
 
             ret = GenerationBatchResult(
-                logits_output=logits_output,
-                next_token_ids=next_token_ids,
+                logits_output=logits_output if self.pp_group.is_last_rank else None,
+                pp_hidden_states_proxy_tensors=(
+                    pp_hidden_states_proxy_tensors
+                    if not self.pp_group.is_last_rank
+                    else None
+                ),
+                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                 bid=bid,
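Inside `run_batch`, the same `forward_batch_generation` call now yields a different payload per stage: the last PP rank gets logits plus sampled token ids, while earlier ranks get hidden-state proxy tensors to ship downstream, and the result object is filled in accordingly. A hedged sketch of that branching with stand-in types (not the real sglang worker or result classes):

```python
from typing import Any, Dict, List, Optional, Tuple


class _FakeWorker:
    """Stand-in worker: always returns a (payload, token_ids) pair."""

    def forward_batch_generation(self, batch: Any) -> Tuple[Any, List[int]]:
        return "logits-or-hidden-states", [1, 2, 3]


def run_generation_step(worker: Any, batch: Any, is_last_pp_rank: bool) -> Dict[str, Optional[Any]]:
    """Sketch of the PP-aware branch of Scheduler.run_batch."""
    logits_output = next_token_ids = proxy_tensors = None

    if is_last_pp_rank:
        # last stage: real logits and sampled token ids
        logits_output, next_token_ids = worker.forward_batch_generation(batch)
    else:
        # intermediate stage: hidden states to forward to the next stage
        proxy_tensors, _ = worker.forward_batch_generation(batch)

    return {
        "logits_output": logits_output if is_last_pp_rank else None,
        "pp_hidden_states_proxy_tensors": proxy_tensors if not is_last_pp_rank else None,
        "next_token_ids": next_token_ids if is_last_pp_rank else None,
    }


print(run_generation_step(_FakeWorker(), batch=None, is_last_pp_rank=True))
print(run_generation_step(_FakeWorker(), batch=None, is_last_pp_rank=False))
```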
@@ -1552,6 +1726,7 @@ class Scheduler(
 
     def move_ready_grammar_requests(self):
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
+
         num_ready_reqs = 0
         for req in self.grammar_queue:
             try:
@@ -1618,7 +1793,11 @@ class Scheduler(
 
     def flush_cache(self):
         """Flush the memory pool and cache."""
-        if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
+        if (
+            len(self.waiting_queue) == 0
+            and self.running_batch.is_empty()
+            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
+        ):
             self.cur_batch = None
             self.last_batch = None
             self.tree_cache.reset()
@@ -1656,7 +1835,6 @@ class Scheduler(
             ret["avg_spec_accept_length"] = (
                 self.cum_spec_accept_length / self.cum_spec_accept_count
             )
-
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict
         return GetInternalStateReqOutput(
@@ -1667,6 +1845,7 @@ class Scheduler(
         server_args_dict = recv_req.server_args
         args_allow_update = set(
             [
+                "max_micro_batch_size",
                 "speculative_accept_threshold_single",
                 "speculative_accept_threshold_acc",
             ]
@@ -1677,6 +1856,14 @@ class Scheduler(
                 logging.warning(f"Updating {k} is not supported.")
                 if_success = False
                 break
+            elif k == "max_micro_batch_size" and (
+                v > self.max_running_requests // self.pp_size or v < 1
+            ):
+                logging.warning(
+                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
+                )
+                if_success = False
+                break
         if if_success:
             if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                 avg_spec_accept_length = (
@@ -1958,6 +2145,16 @@ class Scheduler(
             else:
                 del self.sessions[session_id]
 
+    def get_print_prefix(self):
+        prefix = ""
+        if self.dp_rank is not None:
+            prefix += f" DP{self.dp_rank}"
+        if self.server_args.tp_size > 1:
+            prefix += f" TP{self.tp_rank}"
+        if self.pp_size > 1:
+            prefix += f" PP{self.pp_rank}"
+        return prefix
+
 
 def is_health_check_generate_req(recv_req):
     return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
@@ -1982,14 +2179,18 @@ def run_scheduler_process(
     port_args: PortArgs,
     gpu_id: int,
     tp_rank: int,
+    pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
 ):
     # Generate the prefix
-    if dp_rank is None:
-        prefix = f" TP{tp_rank}"
-    else:
-        prefix = f" DP{dp_rank} TP{tp_rank}"
+    prefix = ""
+    if dp_rank is not None:
+        prefix += f" DP{dp_rank}"
+    if server_args.tp_size > 1:
+        prefix += f" TP{tp_rank}"
+    if server_args.pp_size > 1:
+        prefix += f" PP{pp_rank}"
 
     # Config the process
     kill_itself_when_parent_died()
@@ -2011,7 +2212,7 @@ def run_scheduler_process(
 
     # Create a scheduler and run the event loop
     try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
+        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
         pipe_writer.send(
             {
                 "status": "ready",
@@ -2022,7 +2223,9 @@ def run_scheduler_process(
         disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
 
         if disaggregation_mode == DisaggregationMode.NULL:
-            if scheduler.enable_overlap:
+            if server_args.pp_size > 1:
+                scheduler.event_loop_pp()
+            elif scheduler.enable_overlap:
                 scheduler.event_loop_overlap()
             else:
                 scheduler.event_loop_normal()
@@ -2031,6 +2234,7 @@ def run_scheduler_process(
                 scheduler.event_loop_overlap_disagg_prefill()
             else:
                 scheduler.event_loop_normal_disagg_prefill()
+
         elif disaggregation_mode == DisaggregationMode.DECODE:
             if scheduler.enable_overlap:
                 scheduler.event_loop_overlap_disagg_decode()
sglang/srt/managers/scheduler_output_processor_mixin.py
CHANGED
@@ -278,7 +278,7 @@ class SchedulerOutputProcessorMixin:
             self.attn_tp_rank == 0
             and self.forward_ct_decode % self.server_args.decode_log_interval == 0
         ):
-            self.log_decode_stats()
+            self.log_decode_stats(running_batch=batch)
 
     def add_input_logprob_return_values(
         self: Scheduler,