sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -20,7 +20,6 @@ import signal
 import sys
 import threading
 import time
-import warnings
 from collections import defaultdict, deque
 from concurrent import futures
 from dataclasses import dataclass
@@ -52,7 +51,11 @@ from sglang.srt.disaggregation.utils import (
     TransferBackend,
 )
 from sglang.srt.distributed import get_pp_group, get_world_group
-from sglang.srt.hf_transformers_utils import
+from sglang.srt.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
@@ -83,6 +86,8 @@ from sglang.srt.managers.io_struct import (
     RpcReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
+    SlowDownReqInput,
+    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -115,11 +120,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import
-    ForwardBatch,
-    ForwardMode,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -129,6 +130,7 @@ from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
+    disable_request_logging,
     get_bool_env_var,
     get_zmq_socket,
     kill_itself_when_parent_died,
@@ -147,6 +149,7 @@ logger = logging.getLogger(__name__)
 # Test retract decode for debugging purposes
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
+GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))


 @dataclass
@@ -157,6 +160,7 @@ class GenerationBatchResult:
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
     bid: int
+    can_run_cuda_graph: bool


 @dataclass
@@ -203,7 +207,8 @@ class Scheduler(
         self.page_size = server_args.page_size

         # Distributed rank info
-        self.
+        self.dp_size = server_args.dp_size
+        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
                 self.tp_rank,
@@ -320,13 +325,14 @@ class Scheduler(
         set_random_seed(self.random_seed)

         # Print debug info
-
-
-
-
-
-
-
+        if tp_rank == 0:
+            logger.info(
+                f"max_total_num_tokens={self.max_total_num_tokens}, "
+                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
+                f"max_prefill_tokens={self.max_prefill_tokens}, "
+                f"max_running_requests={self.max_running_requests}, "
+                f"context_len={self.model_config.context_len}"
+            )

         # Init memory pool and cache
         self.init_memory_pool_and_cache()
@@ -413,6 +419,8 @@ class Scheduler(
         self.profiler_id: Optional[str] = None
         self.profiler_target_forward_ct: Optional[int] = None

+        self.forward_sleep_time = None
+
         # Init metrics stats
         self.init_metrics()

@@ -435,6 +443,7 @@ class Scheduler(
                 (GetWeightsByNameReqInput, self.get_weights_by_name),
                 (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
+                (SlowDownReqInput, self.slow_down),
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
@@ -451,17 +460,7 @@ class Scheduler(
     def init_tokenizer(self):
         server_args = self.server_args

-        self.model_config = ModelConfig(
-            server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-        )
+        self.model_config = ModelConfig.from_server_args(server_args)
         self.is_generation = self.model_config.is_generation

         if server_args.skip_tokenizer_init:
@@ -475,7 +474,7 @@ class Scheduler(
                 revision=server_args.revision,
                 use_fast=not server_args.disable_fast_image_processor,
             )
-            self.tokenizer = self.processor
+            self.tokenizer = get_tokenizer_from_processor(self.processor)
         else:
             self.tokenizer = get_tokenizer(
                 server_args.tokenizer_path,
@@ -498,6 +497,7 @@ class Scheduler(
             self.tree_cache = ChunkCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
             )
         else:
             if self.enable_hierarchical_cache:
@@ -531,10 +531,6 @@ class Scheduler(
         )

     def init_metrics(self):
-        # The largest prefill length of a single request
-        self._largest_prefill_len: int = 0
-        # The largest context length (prefill + generation) of a single request
-        self._largest_prefill_decode_len: int = 0
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
@@ -720,7 +716,7 @@ class Scheduler(
                 server_is_idle = False
                 result = self.run_batch(self.cur_batch)

-            # send the outputs to the next step
+            # (last rank) send the outputs to the next step
             if self.pp_group.is_last_rank:
                 if self.cur_batch:
                     next_token_ids, bids[mb_id] = (
@@ -755,24 +751,25 @@ class Scheduler(
                         extend_input_len_per_req=None,
                         extend_logprob_start_len_per_req=None,
                         bid=bids[next_mb_id],
+                        can_run_cuda_graph=result.can_run_cuda_graph,
                     )
                     self.process_batch_result(mbs[next_mb_id], output_result)
                     last_mbs[next_mb_id] = mbs[next_mb_id]

-            #
+            # (not last rank)
             if not self.pp_group.is_last_rank:
                 if self.cur_batch:
                     bids[mb_id] = result.bid
+                # carry the outputs to the next stage
+                # send the outputs from the last round to let the next stage worker run post processing
                 if pp_outputs:
-                    # send the outputs from the last round to let the next stage worker run post processing
                     self.pp_group.send_tensor_dict(
                         pp_outputs.tensors,
                         all_gather_group=self.attn_tp_group,
                     )

-                if not self.pp_group.is_last_rank:
                 # send out reqs to the next stage
-                dp_offset = self.
+                dp_offset = self.attn_dp_rank * self.attn_tp_size
                 if self.attn_tp_rank == 0:
                     point_to_point_pyobj(
                         recv_reqs,
@@ -819,7 +816,7 @@ class Scheduler(
             recv_reqs = None
         else:
             if self.attn_tp_rank == 0:
-                dp_offset = self.
+                dp_offset = self.attn_dp_rank * self.attn_tp_size
                 recv_reqs = point_to_point_pyobj(
                     [],
                     self.pp_rank * self.tp_size + dp_offset,
@@ -907,18 +904,9 @@ class Scheduler(
             fake_input_ids = [1] * seq_length
             recv_req.input_ids = fake_input_ids

-
-
-
-            not self.server_args.enable_custom_logit_processor
-            and custom_logit_processor is not None
-        ):
-            logger.warning(
-                "The SGLang server is not configured to enable custom logit processor."
-                "The custom logit processor passed in will be ignored."
-                "Please set --enable-custom-logits-processor to enable this feature."
-            )
-            custom_logit_processor = None
+        if recv_req.bootstrap_port is None:
+            # Use default bootstrap port
+            recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

         req = Req(
             recv_req.rid,
@@ -931,7 +919,7 @@ class Scheduler(
             stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
-            custom_logit_processor=custom_logit_processor,
+            custom_logit_processor=recv_req.custom_logit_processor,
             return_hidden_states=recv_req.return_hidden_states,
             eos_token_ids=self.model_config.hf_eos_token_id,
             bootstrap_host=recv_req.bootstrap_host,
@@ -1037,9 +1025,11 @@ class Scheduler(
         elif req.sampling_params.structural_tag:
             key = ("structural_tag", req.sampling_params.structural_tag)

-
-
-
+        value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
+        req.grammar = value
+
+        if not cache_hit:
+            req.grammar_key = key
             add_to_grammar_queue = True

         if add_to_grammar_queue:
@@ -1129,9 +1119,6 @@ class Scheduler(
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
         )
-        self._largest_prefill_len = max(
-            self._largest_prefill_len, adder.log_input_tokens
-        )

         num_new_seq = len(can_run_list)
         f = (
@@ -1169,7 +1156,9 @@ class Scheduler(

             self.metrics_collector.log_stats(self.stats)

-    def log_decode_stats(
+    def log_decode_stats(
+        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+    ):
         batch = running_batch or self.running_batch

         gap_latency = time.time() - self.last_decode_stats_tic
@@ -1209,6 +1198,7 @@ class Scheduler(
            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

         msg += (
+            f"cuda graph: {can_run_cuda_graph}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}"
         )
@@ -1221,6 +1211,7 @@ class Scheduler(
             self.stats.cache_hit_rate = 0.0
             self.stats.gen_throughput = self.last_gen_throughput
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.stats.spec_accept_length = spec_accept_length
             self.metrics_collector.log_stats(self.stats)

@@ -1242,9 +1233,7 @@ class Scheduler(
                 f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
             )
-
-            if crash_on_warnings():
-                raise ValueError(msg)
+            raise ValueError(msg)

         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             msg = (
@@ -1252,9 +1241,7 @@ class Scheduler(
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "
                 f"total_size={self.req_to_token_pool.size}\n"
             )
-
-            if crash_on_warnings():
-                raise ValueError(msg)
+            raise ValueError(msg)

         if (
             self.enable_metrics
@@ -1272,6 +1259,7 @@ class Scheduler(
             self.stats.token_usage = num_used / self.max_total_num_tokens
             self.stats.gen_throughput = 0
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.metrics_collector.log_stats(self.stats)

     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
@@ -1342,7 +1330,7 @@ class Scheduler(
             return None

         running_bs = len(self.running_batch.reqs)
-        #
+        # Ignore the check if self.chunked_req is not None.
         # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
         # as the space for the chunked request has just been released.
         # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
@@ -1527,16 +1515,20 @@ class Scheduler(
         ):
             self.stop_profile()

+        if self.forward_sleep_time is not None:
+            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
+            time.sleep(self.forward_sleep_time)
+
         # Run forward
         if self.is_generation:
             if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
                 if self.pp_group.is_last_rank:
-                    logits_output, next_token_ids = (
+                    logits_output, next_token_ids, can_run_cuda_graph = (
                         self.tp_worker.forward_batch_generation(model_worker_batch)
                     )
                 else:
-                    pp_hidden_states_proxy_tensors, _ = (
+                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
                         self.tp_worker.forward_batch_generation(model_worker_batch)
                     )
                 bid = model_worker_batch.bid
@@ -1546,6 +1538,7 @@ class Scheduler(
                     next_token_ids,
                     bid,
                     num_accepted_tokens,
+                    can_run_cuda_graph,
                 ) = self.draft_worker.forward_batch_speculative_generation(batch)
                 self.spec_num_total_accepted_tokens += (
                     num_accepted_tokens + batch.batch_size()
@@ -1579,6 +1572,7 @@ class Scheduler(
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                 bid=bid,
+                can_run_cuda_graph=can_run_cuda_graph,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
@@ -1601,14 +1595,9 @@ class Scheduler(
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
                 self.tp_worker.resolve_last_batch_result(launch_done)
-
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
-                batch.next_batch_sampling_info.sampling_info_done.set()
+                self.set_next_batch_sampling_info_done(batch)
         elif batch.forward_mode.is_dummy_first():
-
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
+            self.set_next_batch_sampling_info_done(batch)

         if self.return_health_check_ct:
             # Return some signal for the health check.
@@ -1622,6 +1611,7 @@ class Scheduler(
                 local_batch,
                 dp_size=self.server_args.dp_size,
                 attn_tp_size=self.attn_tp_size,
+                moe_dense_tp_size=self.server_args.moe_dense_tp_size,
                 tp_cpu_group=self.tp_cpu_group,
                 get_idle_batch=self.get_idle_batch,
                 disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1634,6 +1624,7 @@ class Scheduler(
         local_batch: ScheduleBatch,
         dp_size,
         attn_tp_size: int,
+        moe_dense_tp_size: Optional[int],
         tp_cpu_group,
         get_idle_batch,
         disable_cuda_graph: bool,
@@ -1643,15 +1634,15 @@ class Scheduler(
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
-
+            num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
             if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                 num_tokens = num_tokens * speculative_num_draft_tokens
-
+            num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
-
+            num_tokens_for_logprob = sum(
                 [
                     # We should have at least 1 token for sample in every case.
                     max(extend_len - logprob_start_len, 1)
@@ -1678,7 +1669,7 @@ class Scheduler(
             [
                 num_tokens,
                 can_cuda_graph,
-
+                num_tokens_for_logprob,
                 is_extend_in_batch,
             ],
             dtype=torch.int64,
@@ -1701,8 +1692,15 @@ class Scheduler(
             local_batch = get_idle_batch()

         if local_batch is not None:
-
-
+            # TODO: handle the case when moe_dense_tp_size != 1
+            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
+                local_batch.global_num_tokens = [num_tokens]
+                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
+            else:
+                local_batch.global_num_tokens = global_num_tokens
+                local_batch.global_num_tokens_for_logprob = (
+                    global_num_tokens_for_logprob
+                )

             # Check forward mode for cuda graph
             if not disable_cuda_graph:
@@ -1728,11 +1726,17 @@ class Scheduler(
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""

         num_ready_reqs = 0
+        num_abort_reqs = 0
         for req in self.grammar_queue:
             try:
-                req.grammar = req.grammar.result(timeout=0.
+                req.grammar = req.grammar.result(timeout=0.03)
+                if req.grammar:
+                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 num_ready_reqs += 1
             except futures._base.TimeoutError:
+                req.grammar_wait_ct += 1
+                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
+                    num_abort_reqs = 1
                 break

         if self.server_args.enable_dp_attention:
@@ -1744,18 +1748,39 @@ class Scheduler(

         if tp_size > 1:
             # Sync across TP ranks to make sure they have the same number of ready requests
-            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
             torch.distributed.all_reduce(
                 tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
             )
-            num_ready_reqs_max = tensor.
+            num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()
+
             for i in range(num_ready_reqs, num_ready_reqs_max):
-
-
+                req = self.grammar_queue[i]
+                req.grammar = req.grammar.result()
+                if req.grammar:
+                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
+
+            for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
+                req = self.grammar_queue[i]
+                req.grammar.cancel()
+                req.grammar = None
+                error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
+                logger.error(error_msg)
+                req.finished_reason = FINISH_ABORT(
+                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+                )
+            num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max

         self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

+    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
+        if batch.next_batch_sampling_info:
+            if batch.next_batch_sampling_info.grammars is not None:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+                self.current_stream.synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
         self.watchdog_last_forward_ct = 0
@@ -1766,24 +1791,27 @@ class Scheduler(
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
                     if current > self.watchdog_last_time + self.watchdog_timeout:
-                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                         break
             else:
                 self.watchdog_last_forward_ct = self.forward_ct
                 self.watchdog_last_time = current
             time.sleep(self.watchdog_timeout // 2)

-
-
-
-
-
-
-
-
+        if not disable_request_logging():
+            # Print batch size and memory pool info to check whether there are de-sync issues.
+            logger.error(
+                f"{self.cur_batch.batch_size()=}, "
+                f"{self.cur_batch.reqs=}, "
+                f"{self.token_to_kv_pool_allocator.available_size()=}, "
+                f"{self.tree_cache.evictable_size()=}, "
+            )
+
         pyspy_dump_schedulers()
+        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
         print(file=sys.stderr, flush=True)
         print(file=sys.stdout, flush=True)
+
+        # Wait for some time so that the parent process can print the error.
         time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)

@@ -1915,25 +1943,30 @@ class Scheduler(
         )

     def abort_request(self, recv_req: AbortReq):
+        # TODO(lmzheng): abort the requests in the grammar queue.
+
         # Delete requests in the waiting queue
         to_del = []
         for i, req in enumerate(self.waiting_queue):
             if req.rid.startswith(recv_req.rid):
                 to_del.append(i)
-                break

         # Sort in reverse order to avoid index issues when deleting
-        for i in
+        for i in reversed(to_del):
             req = self.waiting_queue.pop(i)
+            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
             logger.debug(f"Abort queued request. {req.rid=}")
-            return

         # Delete requests in the running batch
-
+        if self.cur_batch is self.running_batch or self.cur_batch is None:
+            reqs = self.running_batch.reqs
+        else:
+            reqs = self.running_batch.reqs + self.cur_batch.reqs
+
+        for req in reqs:
             if req.rid.startswith(recv_req.rid) and not req.finished():
                 logger.debug(f"Abort running request. {req.rid=}")
                 req.to_abort = True
-                return

     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
@@ -2002,6 +2035,13 @@ class Scheduler(
             del self.stashed_model_static_state
         return ResumeMemoryOccupationReqOutput()

+    def slow_down(self, recv_req: SlowDownReqInput):
+        t = recv_req.forward_sleep_time
+        if t is not None and t <= 0:
+            t = None
+        self.forward_sleep_time = t
+        return SlowDownReqOutput()
+
     def profile(self, recv_req: ProfileReq):
         if recv_req.type == ProfileReqType.START_PROFILE:
             return self.start_profile(
@@ -2147,8 +2187,8 @@ class Scheduler(

     def get_print_prefix(self):
         prefix = ""
-        if self.
-            prefix += f" DP{self.
+        if self.attn_dp_rank is not None:
+            prefix += f" DP{self.attn_dp_rank}"
         if self.server_args.tp_size > 1:
             prefix += f" TP{self.tp_rank}"
         if self.pp_size > 1: