sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -11
- sglang/bench_serving.py +149 -1
- sglang/check_env.py +3 -3
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +32 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +151 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +58 -24
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +22 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +129 -94
- sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +6 -1
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +81 -35
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +44 -16
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +291 -72
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +60 -28
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +159 -90
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +2 -277
- sglang/srt/models/deepseek_v2.py +132 -37
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +93 -31
- sglang/srt/models/llama4.py +54 -7
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +4 -16
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +58 -62
- sglang/srt/openai_api/protocol.py +38 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +93 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +123 -10
- sglang/test/runners.py +4 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +32 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
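The remainder of this page walks through the per-file hunks, starting with the scheduler. If you want to regenerate a comparison like this locally, a minimal sketch using only the standard library is shown below; the wheel filenames and the target module path are assumptions for illustration, not values taken from this page.

```python
# Minimal sketch: diff one module between the two wheels with the standard library.
# The wheel filenames below are assumed to follow the usual PyPI naming; adjust them
# to whatever `pip download --no-deps sglang==<version>` actually fetched.
import difflib
import zipfile

OLD_WHEEL = "sglang-0.4.6.post1-py3-none-any.whl"
NEW_WHEEL = "sglang-0.4.6.post3-py3-none-any.whl"
TARGET = "sglang/srt/managers/scheduler.py"

def read_member(wheel_path: str, member: str) -> list[str]:
    # A wheel is a zip archive, so its modules can be read without installing it.
    with zipfile.ZipFile(wheel_path) as zf:
        return zf.read(member).decode("utf-8").splitlines(keepends=True)

diff = difflib.unified_diff(
    read_member(OLD_WHEEL, TARGET),
    read_member(NEW_WHEEL, TARGET),
    fromfile=f"a/{TARGET}",
    tofile=f"b/{TARGET}",
)
print("".join(diff))
```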
sglang/srt/managers/scheduler.py
CHANGED
@@ -51,7 +51,12 @@ from sglang.srt.disaggregation.utils import (
     ReqToMetadataIdxAllocator,
     TransferBackend,
 )
-from sglang.srt.
+from sglang.srt.distributed import get_pp_group, get_world_group
+from sglang.srt.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
@@ -82,6 +87,8 @@ from sglang.srt.managers.io_struct import (
     RpcReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
+    SlowDownReqInput,
+    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -114,7 +121,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -127,6 +138,7 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_zmq_socket,
     kill_itself_when_parent_died,
+    point_to_point_pyobj,
     pyspy_dump_schedulers,
     set_gpu_proc_affinity,
     set_random_seed,
@@ -145,8 +157,9 @@ RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
 
 @dataclass
 class GenerationBatchResult:
-    logits_output: LogitsProcessorOutput
-
+    logits_output: Optional[LogitsProcessorOutput]
+    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
+    next_token_ids: Optional[List[int]]
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
     bid: int
@@ -171,12 +184,16 @@ class Scheduler(
         port_args: PortArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
     ):
         # Parse args
         self.server_args = server_args
         self.tp_rank = tp_rank
+        self.pp_rank = pp_rank
         self.tp_size = server_args.tp_size
+        self.pp_size = server_args.pp_size
+        self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
@@ -192,7 +209,6 @@ class Scheduler(
         self.page_size = server_args.page_size
 
         # Distributed rank info
-        self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
@@ -204,7 +220,7 @@ class Scheduler(
 
         # Init inter-process communication
         context = zmq.Context(2)
-        if self.attn_tp_rank == 0:
+        if self.pp_rank == 0 and self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
             )
@@ -259,6 +275,7 @@ class Scheduler(
                 server_args=server_args,
                 gpu_id=gpu_id,
                 tp_rank=tp_rank,
+                pp_rank=pp_rank,
                 dp_rank=dp_rank,
                 nccl_port=port_args.nccl_port,
             )
@@ -292,8 +309,18 @@ class Scheduler(
             _,
             _,
         ) = self.tp_worker.get_worker_info()
-
+        if global_server_args_dict["max_micro_batch_size"] is None:
+            global_server_args_dict["max_micro_batch_size"] = max(
+                self.max_running_requests // server_args.pp_size, 1
+            )
+
+        self.tp_group = self.tp_worker.get_tp_group()
+        self.tp_cpu_group = self.tp_group.cpu_group
+        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
         self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
+        self.pp_group = get_pp_group()
+        self.world_group = get_world_group()
+
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)
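When `max_micro_batch_size` is left unset, the default above divides `max_running_requests` evenly across the pipeline stages, with a floor of one request per microbatch. A small arithmetic sketch with invented numbers:

```python
# Illustrative numbers only; they are not taken from the diff.
max_running_requests = 128
pp_size = 4

max_micro_batch_size = max(max_running_requests // pp_size, 1)
print(max_micro_batch_size)  # 32 -> each of the 4 in-flight microbatches admits at most 32 requests

# The max(..., 1) floor matters when pp_size exceeds max_running_requests.
print(max(2 // 4, 1))  # 1
```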
@@ -392,6 +419,8 @@ class Scheduler(
         self.profiler_id: Optional[str] = None
         self.profiler_target_forward_ct: Optional[int] = None
 
+        self.forward_sleep_time = None
+
         # Init metrics stats
         self.init_metrics()
 
@@ -414,6 +443,7 @@ class Scheduler(
                 (GetWeightsByNameReqInput, self.get_weights_by_name),
                 (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
+                (SlowDownReqInput, self.slow_down),
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
@@ -430,17 +460,7 @@ class Scheduler(
     def init_tokenizer(self):
         server_args = self.server_args
 
-        self.model_config = ModelConfig(
-            server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-        )
+        self.model_config = ModelConfig.from_server_args(server_args)
         self.is_generation = self.model_config.is_generation
 
         if server_args.skip_tokenizer_init:
@@ -454,7 +474,7 @@ class Scheduler(
                 revision=server_args.revision,
                 use_fast=not server_args.disable_fast_image_processor,
             )
-            self.tokenizer = self.processor
+            self.tokenizer = get_tokenizer_from_processor(self.processor)
         else:
             self.tokenizer = get_tokenizer(
                 server_args.tokenizer_path,
@@ -477,6 +497,7 @@ class Scheduler(
             self.tree_cache = ChunkCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
             )
         else:
             if self.enable_hierarchical_cache:
@@ -673,26 +694,141 @@ class Scheduler(
 
         self.last_batch = batch
 
+    @DynamicGradMode()
+    def event_loop_pp(self):
+        """A non-overlap scheduler loop for pipeline parallelism."""
+        mbs = [None] * self.pp_size
+        last_mbs = [None] * self.pp_size
+        self.running_mbs = [
+            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
+        ]
+        bids = [None] * self.pp_size
+        pp_outputs: Optional[PPProxyTensors] = None
+        while True:
+            server_is_idle = True
+            for mb_id in range(self.pp_size):
+                self.running_batch = self.running_mbs[mb_id]
+                self.last_batch = last_mbs[mb_id]
+
+                recv_reqs = self.recv_requests()
+                self.process_input_requests(recv_reqs)
+                mbs[mb_id] = self.get_next_batch_to_run()
+                self.running_mbs[mb_id] = self.running_batch
+
+                self.cur_batch = mbs[mb_id]
+                if self.cur_batch:
+                    server_is_idle = False
+                    result = self.run_batch(self.cur_batch)
+
+                # send the outputs to the next step
+                if self.pp_group.is_last_rank:
+                    if self.cur_batch:
+                        next_token_ids, bids[mb_id] = (
+                            result.next_token_ids,
+                            result.bid,
+                        )
+                        pp_outputs = PPProxyTensors(
+                            {
+                                "next_token_ids": next_token_ids,
+                            }
+                        )
+                        # send the output from the last round to let the next stage worker run post processing
+                        self.pp_group.send_tensor_dict(
+                            pp_outputs.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                # receive outputs and post-process (filter finished reqs) the coming microbatch
+                next_mb_id = (mb_id + 1) % self.pp_size
+                next_pp_outputs = None
+                if mbs[next_mb_id] is not None:
+                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
+                        self.pp_group.recv_tensor_dict(
+                            all_gather_group=self.attn_tp_group
+                        )
+                    )
+                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
+                    output_result = GenerationBatchResult(
+                        logits_output=None,
+                        pp_hidden_states_proxy_tensors=None,
+                        next_token_ids=next_pp_outputs["next_token_ids"],
+                        extend_input_len_per_req=None,
+                        extend_logprob_start_len_per_req=None,
+                        bid=bids[next_mb_id],
+                    )
+                    self.process_batch_result(mbs[next_mb_id], output_result)
+                    last_mbs[next_mb_id] = mbs[next_mb_id]
+
+                # carry the outputs to the next stage
+                if not self.pp_group.is_last_rank:
+                    if self.cur_batch:
+                        bids[mb_id] = result.bid
+                    if pp_outputs:
+                        # send the outputs from the last round to let the next stage worker run post processing
+                        self.pp_group.send_tensor_dict(
+                            pp_outputs.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                if not self.pp_group.is_last_rank:
+                    # send out reqs to the next stage
+                    dp_offset = self.dp_rank * self.attn_tp_size
+                    if self.attn_tp_rank == 0:
+                        point_to_point_pyobj(
+                            recv_reqs,
+                            self.pp_rank * self.tp_size + dp_offset,
+                            self.world_group.cpu_group,
+                            self.pp_rank * self.tp_size + dp_offset,
+                            (self.pp_rank + 1) * self.tp_size + dp_offset,
+                        )
+
+                    # send out proxy tensors to the next stage
+                    if self.cur_batch:
+                        self.pp_group.send_tensor_dict(
+                            result.pp_hidden_states_proxy_tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                pp_outputs = next_pp_outputs
+
+            # When the server is idle, self-check and re-init some states
+            if server_is_idle:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
-        if self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.pp_rank == 0:
+            if self.attn_tp_rank == 0:
+                recv_reqs = []
+
+                while True:
+                    try:
+                        recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                    except zmq.ZMQError:
+                        break
+                    recv_reqs.append(recv_req)
+
+                while True:
+                    try:
+                        recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
+                    except zmq.ZMQError:
+                        break
+                    recv_reqs.append(recv_rpc)
+            else:
+                recv_reqs = None
         else:
-
+            if self.attn_tp_rank == 0:
+                dp_offset = self.dp_rank * self.attn_tp_size
+                recv_reqs = point_to_point_pyobj(
+                    [],
+                    self.pp_rank * self.tp_size + dp_offset,
+                    self.world_group.cpu_group,
+                    (self.pp_rank - 1) * self.tp_size + dp_offset,
+                    self.pp_rank * self.tp_size + dp_offset,
+                )
+            else:
+                recv_reqs = None
 
         if self.server_args.enable_dp_attention:
             if self.attn_tp_rank == 0:
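The new pipeline loop forwards the received requests (via `point_to_point_pyobj`) and the hidden-state proxy tensors stage by stage. The sender and receiver ranks are derived from `pp_rank`, `tp_size`, and the data-parallel offset; the sketch below simply replays that arithmetic with made-up sizes to show which global ranks talk to each other. It is not sglang code, just an illustration of the rank terms visible in the hunk above.

```python
# Rank arithmetic mirroring the pp_rank / tp_size / dp_offset terms in the hunk above.
# All sizes are invented for illustration.
tp_size = 4
attn_tp_size = 2

def stage_peers(pp_rank: int, dp_rank: int) -> tuple[int, int]:
    dp_offset = dp_rank * attn_tp_size
    src = pp_rank * tp_size + dp_offset        # rank that sends the requests onward
    dst = (pp_rank + 1) * tp_size + dp_offset  # rank on the next stage that receives them
    return src, dst

for pp_rank in range(2):
    src, dst = stage_peers(pp_rank, dp_rank=0)
    print(f"pp_rank={pp_rank}: send from global rank {src} to global rank {dst}")
```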
@@ -715,20 +851,27 @@ class Scheduler(
                 control_reqs = None
 
             if self.attn_tp_size != 1:
-                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
                 work_reqs = broadcast_pyobj(
                     work_reqs,
-                    self.
+                    self.attn_tp_group.rank,
                     self.attn_tp_cpu_group,
-                    src=
+                    src=self.attn_tp_group.ranks[0],
                 )
             if self.tp_size != 1:
                 control_reqs = broadcast_pyobj(
-                    control_reqs,
+                    control_reqs,
+                    self.tp_group.rank,
+                    self.tp_cpu_group,
+                    src=self.tp_group.ranks[0],
                 )
             recv_reqs = work_reqs + control_reqs
         elif self.tp_size != 1:
-            recv_reqs = broadcast_pyobj(
+            recv_reqs = broadcast_pyobj(
+                recv_reqs,
+                self.tp_group.rank,
+                self.tp_cpu_group,
+                src=self.tp_group.ranks[0],
+            )
         return recv_reqs
 
     def process_input_requests(self, recv_reqs: List):
@@ -777,6 +920,10 @@ class Scheduler(
             )
             custom_logit_processor = None
 
+        if recv_req.bootstrap_port is None:
+            # Use default bootstrap port
+            recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
+
         req = Req(
             recv_req.rid,
             recv_req.input_text,
@@ -900,6 +1047,7 @@ class Scheduler(
             add_to_grammar_queue = True
 
         if add_to_grammar_queue:
+            req.queue_time_start = time.time()
             self.grammar_queue.append(req)
         else:
             self._add_request_to_queue(req)
@@ -1025,12 +1173,14 @@ class Scheduler(
 
         self.metrics_collector.log_stats(self.stats)
 
-    def log_decode_stats(self):
+    def log_decode_stats(self, running_batch=None):
+        batch = running_batch or self.running_batch
+
         gap_latency = time.time() - self.last_decode_stats_tic
         self.last_decode_stats_tic = time.time()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
-        num_running_reqs = len(
+        num_running_reqs = len(batch.reqs)
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -1130,19 +1280,25 @@ class Scheduler(
 
     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
+        chunked_req_to_exclude = set()
+        if self.chunked_req:
+            # Move the chunked request out of the batch so that we can merge
+            # only finished requests to running_batch.
+            chunked_req_to_exclude.add(self.chunked_req)
+            self.tree_cache.cache_unfinished_req(self.chunked_req)
+            # chunked request keeps its rid but will get a new req_pool_idx
+            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
-            if self.chunked_req:
-                #
-                #
-                self.last_batch.
-                self.tree_cache.cache_unfinished_req(self.chunked_req)
-                # chunked request keeps its rid but will get a new req_pool_idx
-                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
-                self.running_batch.batch_is_full = False
+            if self.last_batch.chunked_req is not None:
+                # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
+                # We need to discard it.
+                chunked_req_to_exclude.add(self.last_batch.chunked_req)
 
             # Filter batch
             last_bs = self.last_batch.batch_size()
-            self.last_batch.filter_batch(
+            self.last_batch.filter_batch(
+                chunked_req_to_exclude=list(chunked_req_to_exclude)
+            )
             if self.last_batch.batch_size() < last_bs:
                 self.running_batch.batch_is_full = False
 
@@ -1172,6 +1328,12 @@ class Scheduler(
 
         return ret
 
+    def get_num_allocatable_reqs(self, running_bs):
+        res = global_server_args_dict["max_micro_batch_size"] - running_bs
+        if self.pp_size > 1:
+            res = min(res, self.req_to_token_pool.available_size())
+        return res
+
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
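`get_num_allocatable_reqs` combines the per-microbatch cap with, in the pipeline-parallel case, the number of free slots left in the request-to-token pool. A standalone sketch of the same computation with invented values:

```python
# Invented values; mirrors the logic of get_num_allocatable_reqs above.
max_micro_batch_size = 32   # stand-in for global_server_args_dict["max_micro_batch_size"]
running_bs = 30
pp_size = 2
req_pool_free_slots = 1     # stand-in for self.req_to_token_pool.available_size()

allocatable = max_micro_batch_size - running_bs
if pp_size > 1:
    allocatable = min(allocatable, req_pool_free_slots)
print(allocatable)  # 1 -> at most one more request can join this microbatch
```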
@@ -1184,7 +1346,12 @@ class Scheduler(
             return None
 
         running_bs = len(self.running_batch.reqs)
-
+        # Igore the check if self.chunked_req is not None.
+        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
+        # as the space for the chunked request has just been released.
+        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
+        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
+        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
             self.running_batch.batch_is_full = True
             return None
 
@@ -1228,7 +1395,7 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break
 
-            if
+            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                 self.running_batch.batch_is_full = True
                 break
 
@@ -1240,16 +1407,14 @@ class Scheduler(
             res = adder.add_one_req(
                 req, self.chunked_req, self.enable_hierarchical_cache
             )
+
            if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
                     if self.enable_hierarchical_cache:
                         # Set batch_is_full after making sure there are requests that can be served
                         self.running_batch.batch_is_full = len(
                             adder.can_run_list
-                        ) > 0 or (
-                            self.running_batch is not None
-                            and not self.running_batch.is_empty()
-                        )
+                        ) > 0 or (not self.running_batch.is_empty())
                     else:
                         self.running_batch.batch_is_full = True
                 break
@@ -1292,6 +1457,7 @@ class Scheduler(
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
+            chunked_req=self.chunked_req,
         )
         new_batch.prepare_for_extend()
 
@@ -1365,13 +1531,22 @@ class Scheduler(
         ):
             self.stop_profile()
 
+        if self.forward_sleep_time is not None:
+            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
+            time.sleep(self.forward_sleep_time)
+
         # Run forward
         if self.is_generation:
             if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
-
-
-
+                if self.pp_group.is_last_rank:
+                    logits_output, next_token_ids = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
+                else:
+                    pp_hidden_states_proxy_tensors, _ = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
                 bid = model_worker_batch.bid
             else:
                 (
@@ -1385,7 +1560,9 @@ class Scheduler(
                 )
                 self.spec_num_total_forward_ct += batch.batch_size()
                 self.num_generated_tokens += num_accepted_tokens
-
+
+            if self.pp_group.is_last_rank:
+                batch.output_ids = next_token_ids
 
             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
@@ -1400,8 +1577,13 @@ class Scheduler(
                 extend_logprob_start_len_per_req = None
 
             ret = GenerationBatchResult(
-                logits_output=logits_output,
-
+                logits_output=logits_output if self.pp_group.is_last_rank else None,
+                pp_hidden_states_proxy_tensors=(
+                    pp_hidden_states_proxy_tensors
+                    if not self.pp_group.is_last_rank
+                    else None
+                ),
+                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                 bid=bid,
@@ -1552,6 +1734,7 @@ class Scheduler(
 
     def move_ready_grammar_requests(self):
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
+
         num_ready_reqs = 0
         for req in self.grammar_queue:
             try:
@@ -1618,7 +1801,11 @@ class Scheduler(
 
     def flush_cache(self):
         """Flush the memory pool and cache."""
-        if
+        if (
+            len(self.waiting_queue) == 0
+            and self.running_batch.is_empty()
+            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
+        ):
             self.cur_batch = None
             self.last_batch = None
             self.tree_cache.reset()
@@ -1656,7 +1843,6 @@ class Scheduler(
             ret["avg_spec_accept_length"] = (
                 self.cum_spec_accept_length / self.cum_spec_accept_count
             )
-
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict
         return GetInternalStateReqOutput(
@@ -1667,6 +1853,7 @@ class Scheduler(
         server_args_dict = recv_req.server_args
         args_allow_update = set(
             [
+                "max_micro_batch_size",
                 "speculative_accept_threshold_single",
                 "speculative_accept_threshold_acc",
             ]
@@ -1677,6 +1864,14 @@ class Scheduler(
                 logging.warning(f"Updating {k} is not supported.")
                 if_success = False
                 break
+            elif k == "max_micro_batch_size" and (
+                v > self.max_running_requests // self.pp_size or v < 1
+            ):
+                logging.warning(
+                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
+                )
+                if_success = False
+                break
         if if_success:
             if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                 avg_spec_accept_length = (
@@ -1815,6 +2010,13 @@ class Scheduler(
         del self.stashed_model_static_state
         return ResumeMemoryOccupationReqOutput()
 
+    def slow_down(self, recv_req: SlowDownReqInput):
+        t = recv_req.forward_sleep_time
+        if t is not None and t <= 0:
+            t = None
+        self.forward_sleep_time = t
+        return SlowDownReqOutput()
+
     def profile(self, recv_req: ProfileReq):
         if recv_req.type == ProfileReqType.START_PROFILE:
             return self.start_profile(
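The new `slow_down` handler stores a per-forward sleep interval that `run_batch` honors before every forward pass; a non-positive or missing value clears it. A tiny standalone sketch of that normalization (not the sglang class itself):

```python
# Standalone sketch of the slow_down normalization shown above.
def normalize_forward_sleep_time(t):
    # Positive values throttle every forward pass; zero, negative, or None disable throttling.
    return None if t is not None and t <= 0 else t

print(normalize_forward_sleep_time(0.5))   # 0.5  -> sleep 0.5 s before each run_batch
print(normalize_forward_sleep_time(0))     # None -> throttling disabled
print(normalize_forward_sleep_time(None))  # None -> throttling disabled
```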
@@ -1958,6 +2160,16 @@ class Scheduler(
         else:
             del self.sessions[session_id]
 
+    def get_print_prefix(self):
+        prefix = ""
+        if self.dp_rank is not None:
+            prefix += f" DP{self.dp_rank}"
+        if self.server_args.tp_size > 1:
+            prefix += f" TP{self.tp_rank}"
+        if self.pp_size > 1:
+            prefix += f" PP{self.pp_rank}"
+        return prefix
+
 
 def is_health_check_generate_req(recv_req):
     return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
@@ -1982,14 +2194,18 @@ class Scheduler(
     port_args: PortArgs,
     gpu_id: int,
     tp_rank: int,
+    pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
 ):
     # Generate the prefix
-
-
-
-
+    prefix = ""
+    if dp_rank is not None:
+        prefix += f" DP{dp_rank}"
+    if server_args.tp_size > 1:
+        prefix += f" TP{tp_rank}"
+    if server_args.pp_size > 1:
+        prefix += f" PP{pp_rank}"
 
     # Config the process
     kill_itself_when_parent_died()
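Both `get_print_prefix` and `run_scheduler_process` now build the same rank prefix, adding a `PP` component only when pipeline parallelism is enabled. A standalone reconstruction of that string building (the ranks and sizes are made-up examples):

```python
# Standalone reconstruction of the prefix logic above; ranks and sizes are examples.
from typing import Optional

def build_prefix(dp_rank: Optional[int], tp_rank: int, pp_rank: int,
                 tp_size: int, pp_size: int) -> str:
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if tp_size > 1:
        prefix += f" TP{tp_rank}"
    if pp_size > 1:
        prefix += f" PP{pp_rank}"
    return prefix

print(repr(build_prefix(dp_rank=0, tp_rank=1, pp_rank=1, tp_size=2, pp_size=2)))
# ' DP0 TP1 PP1'
```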
@@ -2011,7 +2227,7 @@ class Scheduler(
 
     # Create a scheduler and run the event loop
     try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
+        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
         pipe_writer.send(
             {
                 "status": "ready",
@@ -2022,7 +2238,9 @@ class Scheduler(
         disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
 
         if disaggregation_mode == DisaggregationMode.NULL:
-            if
+            if server_args.pp_size > 1:
+                scheduler.event_loop_pp()
+            elif scheduler.enable_overlap:
                 scheduler.event_loop_overlap()
             else:
                 scheduler.event_loop_normal()
@@ -2031,6 +2249,7 @@ class Scheduler(
                 scheduler.event_loop_overlap_disagg_prefill()
             else:
                 scheduler.event_loop_normal_disagg_prefill()
+
         elif disaggregation_mode == DisaggregationMode.DECODE:
             if scheduler.enable_overlap:
                 scheduler.event_loop_overlap_disagg_decode()
sglang/srt/managers/scheduler_output_processor_mixin.py
CHANGED
@@ -278,7 +278,7 @@ class SchedulerOutputProcessorMixin:
             self.attn_tp_rank == 0
             and self.forward_ct_decode % self.server_args.decode_log_interval == 0
         ):
-            self.log_decode_stats()
+            self.log_decode_stats(running_batch=batch)
 
     def add_input_logprob_return_values(
         self: Scheduler,