sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects only the changes between those published versions.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -20,7 +20,6 @@ import signal
 import sys
 import threading
 import time
-import warnings
 from collections import defaultdict, deque
 from concurrent import futures
 from dataclasses import dataclass
@@ -42,14 +41,17 @@ from sglang.srt.disaggregation.decode import (
     DecodeTransferQueue,
     SchedulerDisaggregationDecodeMixin,
 )
+from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.prefill import (
     PrefillBootstrapQueue,
     SchedulerDisaggregationPrefillMixin,
 )
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
+    prepare_abort,
 )
 from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.hf_transformers_utils import (
@@ -59,7 +61,10 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -98,6 +103,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
+from sglang.srt.managers.mm_utils import init_embedding_cache
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     MultimodalInputs,
@@ -121,11 +127,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    ForwardMode,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -134,7 +136,7 @@ from sglang.srt.utils import (
     DynamicGradMode,
     broadcast_pyobj,
     configure_logger,
-    crash_on_warnings,
+    disable_request_logging,
     get_bool_env_var,
     get_zmq_socket,
     kill_itself_when_parent_died,
@@ -146,13 +148,12 @@ from sglang.srt.utils import (
 )
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback

-expert_distribution_recorder = ExpertDistributionRecorder()
-
 logger = logging.getLogger(__name__)

 # Test retract decode for debugging purposes
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
+GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))


 @dataclass
@@ -163,6 +164,7 @@ class GenerationBatchResult:
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
     bid: int
+    can_run_cuda_graph: bool


 @dataclass
@@ -200,6 +202,7 @@ class Scheduler(
         self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
+        self.enable_kv_cache_events = server_args.kv_events_config is not None
         self.stream_interval = server_args.stream_interval
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
@@ -207,9 +210,9 @@ class Scheduler(
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
         self.page_size = server_args.page_size
-
         # Distributed rank info
-        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+        self.dp_size = server_args.dp_size
+        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
                 self.tp_rank,
@@ -326,13 +329,14 @@ class Scheduler(
         set_random_seed(self.random_seed)

         # Print debug info
-        logger.info(
-            f"max_total_num_tokens={self.max_total_num_tokens}, "
-            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
-            f"max_prefill_tokens={self.max_prefill_tokens}, "
-            f"max_running_requests={self.max_running_requests}, "
-            f"context_len={self.model_config.context_len}"
-        )
+        if tp_rank == 0:
+            logger.info(
+                f"max_total_num_tokens={self.max_total_num_tokens}, "
+                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
+                f"max_prefill_tokens={self.max_prefill_tokens}, "
+                f"max_running_requests={self.max_running_requests}, "
+                f"context_len={self.model_config.context_len}"
+            )

         # Init memory pool and cache
         self.init_memory_pool_and_cache()
@@ -349,8 +353,8 @@ class Scheduler(
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
         self.num_prefill_tokens = 0
-        self.last_decode_stats_tic = time.time()
-        self.last_prefill_stats_tic = time.time()
+        self.last_decode_stats_tic = time.perf_counter()
+        self.last_prefill_stats_tic = time.perf_counter()
         self.return_health_check_ct = 0
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
@@ -423,6 +427,7 @@ class Scheduler(

         # Init metrics stats
         self.init_metrics()
+        self.init_kv_events(server_args.kv_events_config)

         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
@@ -516,6 +521,7 @@ class Scheduler(
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
+                enable_kv_cache_events=self.enable_kv_cache_events,
             )

         self.decode_mem_cache_buf_multiplier = (
@@ -531,10 +537,6 @@ class Scheduler(
         )

     def init_metrics(self):
-        # The largest prefill length of a single request
-        self._largest_prefill_len: int = 0
-        # The largest context length (prefill + generation) of a single request
-        self._largest_prefill_decode_len: int = 0
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
@@ -552,6 +554,10 @@ class Scheduler(
             },
         )

+    def init_kv_events(self, kv_events_config: Optional[str]):
+        if self.enable_kv_cache_events:
+            self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)
+
     def init_disaggregation(self):
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
@@ -564,29 +570,28 @@ class Scheduler(
             req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-
-            # A list of metadata buffers. The shape is (b, metadata_size) where
-            # b corresponds to a max running requests. The last shape * dtype.itemsize
-            # should be larger than 64 bytes to work with RDMA, so we pad it.
-            output_id_buffer = torch.zeros(
-                (buffer_size, 16), dtype=aux_dtype, device="cpu"
-            )
-            metadata_buffers = [output_id_buffer]
+            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
                 gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
+                metadata_buffers=self.disagg_metadata_buffers,
+                scheduler=self,
+                tree_cache=self.tree_cache,
             )

             # The decode requests pending for pre-allocation
             self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                draft_token_to_kv_pool=(
+                    None
+                    if self.draft_worker is None
+                    else self.draft_worker.model_runner.token_to_kv_pool
+                ),
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
-                aux_dtype=aux_dtype,
+                metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 transfer_queue=self.disagg_decode_transfer_queue,
                 tree_cache=self.tree_cache,
@@ -606,20 +611,17 @@ class Scheduler(
             req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-
-            # A list of metadata buffers. The shape is (b, metadata_size) where
-            # b corresponds to a max running requests. The last shape * dtype.itemsize
-            # should be larger than 64 bytes to work with RDMA, so we pad it.
-            output_id_buffer = torch.zeros(
-                (buffer_size, 16), dtype=aux_dtype, device="cpu"
-            )
-            metadata_buffers = [output_id_buffer]
+            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

             self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
+                draft_token_to_kv_pool=(
+                    None
+                    if self.draft_worker is None
+                    else self.draft_worker.model_runner.token_to_kv_pool
+                ),
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
-                aux_dtype=aux_dtype,
+                metadata_buffers=self.disagg_metadata_buffers,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
@@ -720,7 +722,7 @@ class Scheduler(
                 server_is_idle = False
                 result = self.run_batch(self.cur_batch)

-            # send the outputs to the next step
+            # (last rank) send the outputs to the next step
            if self.pp_group.is_last_rank:
                 if self.cur_batch:
                     next_token_ids, bids[mb_id] = (
@@ -755,24 +757,25 @@ class Scheduler(
                         extend_input_len_per_req=None,
                         extend_logprob_start_len_per_req=None,
                         bid=bids[next_mb_id],
+                        can_run_cuda_graph=result.can_run_cuda_graph,
                     )
                     self.process_batch_result(mbs[next_mb_id], output_result)
                     last_mbs[next_mb_id] = mbs[next_mb_id]

-            #
+            # (not last rank)
             if not self.pp_group.is_last_rank:
                 if self.cur_batch:
                     bids[mb_id] = result.bid
+                # carry the outputs to the next stage
+                # send the outputs from the last round to let the next stage worker run post processing
                 if pp_outputs:
-                    # send the outputs from the last round to let the next stage worker run post processing
                     self.pp_group.send_tensor_dict(
                         pp_outputs.tensors,
                         all_gather_group=self.attn_tp_group,
                     )

-            if not self.pp_group.is_last_rank:
                 # send out reqs to the next stage
-                dp_offset = self.dp_rank * self.attn_tp_size
+                dp_offset = self.attn_dp_rank * self.attn_tp_size
                 if self.attn_tp_rank == 0:
                     point_to_point_pyobj(
                         recv_reqs,
@@ -819,7 +822,7 @@ class Scheduler(
             recv_reqs = None
         else:
             if self.attn_tp_rank == 0:
-                dp_offset = self.dp_rank * self.attn_tp_size
+                dp_offset = self.attn_dp_rank * self.attn_tp_size
                 recv_reqs = point_to_point_pyobj(
                     [],
                     self.pp_rank * self.tp_size + dp_offset,
@@ -907,19 +910,6 @@ class Scheduler(
             fake_input_ids = [1] * seq_length
             recv_req.input_ids = fake_input_ids

-        # Handle custom logit processor passed to the request
-        custom_logit_processor = recv_req.custom_logit_processor
-        if (
-            not self.server_args.enable_custom_logit_processor
-            and custom_logit_processor is not None
-        ):
-            logger.warning(
-                "The SGLang server is not configured to enable custom logit processor."
-                "The custom logit processor passed in will be ignored."
-                "Please set --enable-custom-logits-processor to enable this feature."
-            )
-            custom_logit_processor = None
-
         if recv_req.bootstrap_port is None:
             # Use default bootstrap port
             recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
@@ -935,7 +925,7 @@ class Scheduler(
             stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
-            custom_logit_processor=custom_logit_processor,
+            custom_logit_processor=recv_req.custom_logit_processor,
             return_hidden_states=recv_req.return_hidden_states,
             eos_token_ids=self.model_config.hf_eos_token_id,
             bootstrap_host=recv_req.bootstrap_host,
@@ -944,6 +934,18 @@ class Scheduler(
         )
         req.tokenizer = self.tokenizer

+        if self.disaggregation_mode != DisaggregationMode.NULL:
+            # Invalid request for disaggregated mode
+            if recv_req.bootstrap_room is None:
+                error_message = (
+                    f"Invalid request: Disaggregated request received without "
+                    f"boostrap room id. {req.rid=}"
+                )
+                logger.error(error_message)
+                prepare_abort(req, error_message)
+                self.stream_output([req], req.return_logprob)
+                return
+
         if (
             recv_req.session_params is not None
             and recv_req.session_params.id is not None
@@ -1041,19 +1043,21 @@ class Scheduler(
             elif req.sampling_params.structural_tag:
                 key = ("structural_tag", req.sampling_params.structural_tag)

-            req.grammar = self.grammar_backend.get_cached_value(key)
-            if not req.grammar:
-                req.grammar = self.grammar_backend.get_future_value(key)
+            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
+            req.grammar = value
+
+            if not cache_hit:
+                req.grammar_key = key
                 add_to_grammar_queue = True

         if add_to_grammar_queue:
-            req.queue_time_start = time.time()
+            req.queue_time_start = time.perf_counter()
             self.grammar_queue.append(req)
         else:
             self._add_request_to_queue(req)

     def _add_request_to_queue(self, req: Req):
-        req.queue_time_start = time.time()
+        req.queue_time_start = time.perf_counter()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.disagg_prefill_bootstrap_queue.add(req)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
@@ -1061,8 +1065,11 @@ class Scheduler(
         else:
             self.waiting_queue.append(req)

-    def _extend_requests_to_queue(self, reqs: List[Req]
-        if self.disaggregation_mode == DisaggregationMode.DECODE:
+    def _extend_requests_to_queue(self, reqs: List[Req]):
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.disagg_prefill_bootstrap_queue.extend(reqs)
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            # If this is a decode server, we put the request to the decode pending prealloc queue
             self.disagg_decode_prealloc_queue.extend(reqs)
         else:
             self.waiting_queue.extend(reqs)
@@ -1100,7 +1107,7 @@ class Scheduler(
             req.finished_reason = FINISH_ABORT(
                 error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
             )
-            req.queue_time_start = time.time()
+            req.queue_time_start = time.perf_counter()
             self.waiting_queue.append(req)
             return

@@ -1124,8 +1131,8 @@ class Scheduler(
         can_run_list: List[Req],
         running_bs: int,
     ):
-        gap_latency = time.time() - self.last_prefill_stats_tic
-        self.last_prefill_stats_tic = time.time()
+        gap_latency = time.perf_counter() - self.last_prefill_stats_tic
+        self.last_prefill_stats_tic = time.perf_counter()
         self.last_input_throughput = self.num_prefill_tokens / gap_latency
         self.num_prefill_tokens = 0

@@ -1133,9 +1140,6 @@ class Scheduler(
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
         )
-        self._largest_prefill_len = max(
-            self._largest_prefill_len, adder.log_input_tokens
-        )

         num_new_seq = len(can_run_list)
         f = (
@@ -1172,12 +1176,15 @@ class Scheduler(
             self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq

             self.metrics_collector.log_stats(self.stats)
+            self._publish_kv_events()

-    def log_decode_stats(self, running_batch: ScheduleBatch = None):
+    def log_decode_stats(
+        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+    ):
         batch = running_batch or self.running_batch

-        gap_latency = time.time() - self.last_decode_stats_tic
-        self.last_decode_stats_tic = time.time()
+        gap_latency = time.perf_counter() - self.last_decode_stats_tic
+        self.last_decode_stats_tic = time.perf_counter()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
         num_running_reqs = len(batch.reqs)
@@ -1213,6 +1220,7 @@ class Scheduler(
         msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

         msg += (
+            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
         )
@@ -1225,8 +1233,10 @@ class Scheduler(
             self.stats.cache_hit_rate = 0.0
             self.stats.gen_throughput = self.last_gen_throughput
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.stats.spec_accept_length = spec_accept_length
             self.metrics_collector.log_stats(self.stats)
+            self._publish_kv_events()

     def check_memory(self):
         available_size = (
@@ -1246,9 +1256,7 @@ class Scheduler(
                 f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
             )
-
-            if crash_on_warnings():
-                raise ValueError(msg)
+            raise ValueError(msg)

         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             msg = (
@@ -1256,14 +1264,12 @@ class Scheduler(
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "
                 f"total_size={self.req_to_token_pool.size}\n"
             )
-
-            if crash_on_warnings():
-                raise ValueError(msg)
+            raise ValueError(msg)

         if (
             self.enable_metrics
             and self.attn_tp_rank == 0
-            and time.time() > self.metrics_collector.last_log_time + 30
+            and time.perf_counter() > self.metrics_collector.last_log_time + 30
         ):
             # During idle time, also collect metrics every 30 seconds.
             num_used = self.max_total_num_tokens - (
@@ -1276,7 +1282,9 @@ class Scheduler(
             self.stats.token_usage = num_used / self.max_total_num_tokens
             self.stats.gen_throughput = 0
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.metrics_collector.log_stats(self.stats)
+            self._publish_kv_events()

     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
@@ -1346,7 +1354,7 @@ class Scheduler(
             return None

         running_bs = len(self.running_batch.reqs)
-        #
+        # Ignore the check if self.chunked_req is not None.
         # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
         # as the space for the chunked request has just been released.
         # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
@@ -1399,6 +1407,13 @@ class Scheduler(
                     self.running_batch.batch_is_full = True
                    break

+            if self.disaggregation_mode == DisaggregationMode.PREFILL:
+                # In prefill mode, prealloc queue and transfer queue can also take memory,
+                # so we need to check if the available size for the actual available size.
+                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
+                    self.running_batch.batch_is_full = True
+                    break
+
             req.init_next_round_input(
                 None if prefix_computed else self.tree_cache,
                 self.enable_hierarchical_cache,
@@ -1427,7 +1442,7 @@ class Scheduler(
         if self.enable_metrics:
             # only record queue time when enable_metrics is True to avoid overhead
             for req in can_run_list:
-                req.queue_time_end = time.time()
+                req.queue_time_end = time.perf_counter()

         self.waiting_queue = [
             x for x in self.waiting_queue if x not in set(can_run_list)
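The queue-latency hunks above replace `time.time()` with `time.perf_counter()`, a monotonic clock that is not affected by system clock adjustments, so `queue_time_end - queue_time_start` stays meaningful. A minimal, self-contained sketch of that bookkeeping pattern (the `Request` class here is a hypothetical stand-in, not sglang's `Req`):

```python
import time
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Request:
    # Hypothetical stand-in for the timing fields on sglang's Req.
    queue_time_start: Optional[float] = None
    queue_time_end: Optional[float] = None


def enqueue(waiting_queue: List[Request], req: Request) -> None:
    # Stamp the request with a monotonic timestamp when it enters the queue.
    req.queue_time_start = time.perf_counter()
    waiting_queue.append(req)


def mark_scheduled(can_run_list: List[Request]) -> float:
    # Stamp the requests when they are scheduled and return the average queue latency.
    now = time.perf_counter()
    for req in can_run_list:
        req.queue_time_end = now
    if not can_run_list:
        return 0.0
    total = sum(r.queue_time_end - r.queue_time_start for r in can_run_list)
    return total / len(can_run_list)
```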
@@ -1529,7 +1544,7 @@ class Scheduler(
             self.profiler_target_forward_ct
             and self.profiler_target_forward_ct <= self.forward_ct
         ):
-            self.stop_profile()
+            self.send_to_tokenizer.send_pyobj(self.stop_profile())

         if self.forward_sleep_time is not None:
             logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
@@ -1540,11 +1555,11 @@ class Scheduler(
         if self.spec_algorithm.is_none():
             model_worker_batch = batch.get_model_worker_batch()
             if self.pp_group.is_last_rank:
-                logits_output, next_token_ids = (
+                logits_output, next_token_ids, can_run_cuda_graph = (
                     self.tp_worker.forward_batch_generation(model_worker_batch)
                 )
             else:
-                pp_hidden_states_proxy_tensors, _ = (
+                pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
                     self.tp_worker.forward_batch_generation(model_worker_batch)
                 )
             bid = model_worker_batch.bid
@@ -1554,6 +1569,7 @@ class Scheduler(
                 next_token_ids,
                 bid,
                 num_accepted_tokens,
+                can_run_cuda_graph,
             ) = self.draft_worker.forward_batch_speculative_generation(batch)
             self.spec_num_total_accepted_tokens += (
                 num_accepted_tokens + batch.batch_size()
@@ -1587,6 +1603,7 @@ class Scheduler(
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                 bid=bid,
+                can_run_cuda_graph=can_run_cuda_graph,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
@@ -1609,14 +1626,9 @@ class Scheduler(
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
                 self.tp_worker.resolve_last_batch_result(launch_done)
-
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
-                batch.next_batch_sampling_info.sampling_info_done.set()
+                self.set_next_batch_sampling_info_done(batch)
         elif batch.forward_mode.is_dummy_first():
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
+            self.set_next_batch_sampling_info_done(batch)

         if self.return_health_check_ct:
             # Return some signal for the health check.
@@ -1630,6 +1642,7 @@ class Scheduler(
                 local_batch,
                 dp_size=self.server_args.dp_size,
                 attn_tp_size=self.attn_tp_size,
+                moe_dense_tp_size=self.server_args.moe_dense_tp_size,
                 tp_cpu_group=self.tp_cpu_group,
                 get_idle_batch=self.get_idle_batch,
                 disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1642,6 +1655,7 @@ class Scheduler(
         local_batch: ScheduleBatch,
         dp_size,
         attn_tp_size: int,
+        moe_dense_tp_size: Optional[int],
         tp_cpu_group,
         get_idle_batch,
         disable_cuda_graph: bool,
@@ -1651,15 +1665,15 @@ class Scheduler(
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
-            global_num_tokens_for_logprob = 0
+            num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
             if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                 num_tokens = num_tokens * speculative_num_draft_tokens
-            global_num_tokens_for_logprob = num_tokens
+            num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
-            global_num_tokens_for_logprob = sum(
+            num_tokens_for_logprob = sum(
                 [
                     # We should have at least 1 token for sample in every case.
                     max(extend_len - logprob_start_len, 1)
@@ -1686,7 +1700,7 @@ class Scheduler(
             [
                 num_tokens,
                 can_cuda_graph,
-                global_num_tokens_for_logprob,
+                num_tokens_for_logprob,
                 is_extend_in_batch,
             ],
             dtype=torch.int64,
@@ -1709,8 +1723,15 @@ class Scheduler(
             local_batch = get_idle_batch()

         if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens
-            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
+            # TODO: handle the case when moe_dense_tp_size != 1
+            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
+                local_batch.global_num_tokens = [num_tokens]
+                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
+            else:
+                local_batch.global_num_tokens = global_num_tokens
+                local_batch.global_num_tokens_for_logprob = (
+                    global_num_tokens_for_logprob
+                )

         # Check forward mode for cuda graph
         if not disable_cuda_graph:
@@ -1736,11 +1757,17 @@ class Scheduler(
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""

         num_ready_reqs = 0
+        num_abort_reqs = 0
         for req in self.grammar_queue:
             try:
-                req.grammar = req.grammar.result(timeout=0.
+                req.grammar = req.grammar.result(timeout=0.03)
+                if req.grammar:
+                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 num_ready_reqs += 1
             except futures._base.TimeoutError:
+                req.grammar_wait_ct += 1
+                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
+                    num_abort_reqs = 1
                 break

         if self.server_args.enable_dp_attention:
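The hunk above polls each grammar-compilation future for 30 ms per scheduler pass and counts how often a request has timed out; once the accumulated wait exceeds `SGLANG_GRAMMAR_TIMEOUT` (300 seconds by default), the request is flagged for abort. A standalone sketch of that poll-and-count pattern with `concurrent.futures` (the helper names and the toy `slow_compile` task are illustrative, not the scheduler's own code):

```python
import concurrent.futures
import os
import time

GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
POLL_INTERVAL = 0.03  # seconds per poll, mirroring the timeout used in the hunk above


def slow_compile(spec: str) -> str:
    time.sleep(0.1)  # stand-in for grammar compilation work
    return f"compiled({spec})"


def poll_once(future, wait_ct):
    """Poll the future briefly; return (result_or_None, new_wait_ct, should_abort)."""
    try:
        return future.result(timeout=POLL_INTERVAL), wait_ct, False
    except concurrent.futures.TimeoutError:
        wait_ct += 1
        # Abort once the accumulated polling time exceeds the configured budget.
        return None, wait_ct, wait_ct > GRAMMAR_TIMEOUT / POLL_INTERVAL


if __name__ == "__main__":
    with concurrent.futures.ThreadPoolExecutor() as pool:
        future = pool.submit(slow_compile, "json_schema")
        result, wait_ct, abort = None, 0, False
        while result is None and not abort:
            result, wait_ct, abort = poll_once(future, wait_ct)
        print(result, wait_ct, abort)
```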
@@ -1752,46 +1779,70 @@ class Scheduler(

         if tp_size > 1:
             # Sync across TP ranks to make sure they have the same number of ready requests
-            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
             torch.distributed.all_reduce(
                 tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
             )
-            num_ready_reqs_max = tensor.item()
+            num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()
+
             for i in range(num_ready_reqs, num_ready_reqs_max):
-                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
-            num_ready_reqs = num_ready_reqs_max
+                req = self.grammar_queue[i]
+                req.grammar = req.grammar.result()
+                if req.grammar:
+                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
+
+            for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
+                req = self.grammar_queue[i]
+                req.grammar.cancel()
+                req.grammar = None
+                error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
+                logger.error(error_msg)
+                req.finished_reason = FINISH_ABORT(
+                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+                )
+            num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max

         self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

+    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
+        if batch.next_batch_sampling_info:
+            if batch.next_batch_sampling_info.grammars is not None:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+                self.current_stream.synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
         self.watchdog_last_forward_ct = 0
-        self.watchdog_last_time = time.time()
+        self.watchdog_last_time = time.perf_counter()

         while True:
-            current = time.time()
+            current = time.perf_counter()
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
                     if current > self.watchdog_last_time + self.watchdog_timeout:
-                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                         break
                 else:
                     self.watchdog_last_forward_ct = self.forward_ct
                     self.watchdog_last_time = current
             time.sleep(self.watchdog_timeout // 2)

-        # Print batch size and memory pool info to check whether there are de-sync issues.
-        logger.error(
-            f"{self.cur_batch.batch_size()=}, "
-            f"{self.cur_batch.reqs=}, "
-            f"{self.token_to_kv_pool_allocator.available_size()=}, "
-            f"{self.tree_cache.evictable_size()=}, "
-        )
-
+        if not disable_request_logging():
+            # Print batch size and memory pool info to check whether there are de-sync issues.
+            logger.error(
+                f"{self.cur_batch.batch_size()=}, "
+                f"{self.cur_batch.reqs=}, "
+                f"{self.token_to_kv_pool_allocator.available_size()=}, "
+                f"{self.tree_cache.evictable_size()=}, "
+            )
+
         pyspy_dump_schedulers()
+        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
         print(file=sys.stderr, flush=True)
         print(file=sys.stdout, flush=True)
+
+        # Wait for some time so that the parent process can print the error.
         time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)
@@ -1923,25 +1974,30 @@ class Scheduler(
         )

     def abort_request(self, recv_req: AbortReq):
+        # TODO(lmzheng): abort the requests in the grammar queue.
+
         # Delete requests in the waiting queue
         to_del = []
         for i, req in enumerate(self.waiting_queue):
             if req.rid.startswith(recv_req.rid):
                 to_del.append(i)
-                break

         # Sort in reverse order to avoid index issues when deleting
-        for i in sorted(to_del, reverse=True):
+        for i in reversed(to_del):
             req = self.waiting_queue.pop(i)
+            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
             logger.debug(f"Abort queued request. {req.rid=}")
-            return

         # Delete requests in the running batch
-        for req in self.running_batch.reqs:
+        if self.cur_batch is self.running_batch or self.cur_batch is None:
+            reqs = self.running_batch.reqs
+        else:
+            reqs = self.running_batch.reqs + self.cur_batch.reqs
+
+        for req in reqs:
             if req.rid.startswith(recv_req.rid) and not req.finished():
                 logger.debug(f"Abort running request. {req.rid=}")
                 req.to_abort = True
-                return

     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
@@ -2090,7 +2146,10 @@ class Scheduler(

     def stop_profile(self) -> None:
         if self.profiler_activities is None:
-            return
+            return ProfileReqOutput(
+                success=False,
+                message="Profiling is not in progress. Call /start_profile first.",
+            )

         logger.info("Stop profiling...")
         if self.torch_profiler is not None:
@@ -2121,18 +2180,15 @@ class Scheduler(
         self.torch_profiler_output_dir = None
         self.profiler_activities = None

-
-        self.send_to_tokenizer.send_pyobj(
-            ProfileReqOutput(success=True, message="Succeeded.")
-        )
+        return ProfileReqOutput(success=True, message="Succeeded")

     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
         if recv_req == ExpertDistributionReq.START_RECORD:
-            expert_distribution_recorder.start_record()
+            get_global_expert_distribution_recorder().start_record()
         elif recv_req == ExpertDistributionReq.STOP_RECORD:
-            expert_distribution_recorder.stop_record()
+            get_global_expert_distribution_recorder().stop_record()
         elif recv_req == ExpertDistributionReq.DUMP_RECORD:
-            expert_distribution_recorder.dump_record()
+            get_global_expert_distribution_recorder().dump_record()
         else:
             raise ValueError("Unrecognized ExpertDistributionReq value")
         return ExpertDistributionReqOutput()
@@ -2162,14 +2218,21 @@ class Scheduler(

     def get_print_prefix(self):
         prefix = ""
-        if self.dp_rank is not None:
-            prefix += f" DP{self.dp_rank}"
+        if self.attn_dp_rank is not None:
+            prefix += f" DP{self.attn_dp_rank}"
         if self.server_args.tp_size > 1:
             prefix += f" TP{self.tp_rank}"
         if self.pp_size > 1:
             prefix += f" PP{self.pp_rank}"
         return prefix

+    def _publish_kv_events(self):
+        if self.enable_kv_cache_events:
+            events = self.tree_cache.take_events()
+            if events:
+                batch = KVEventBatch(ts=time.time(), events=events)
+                self.kv_event_publisher.publish(batch)
+

 def is_health_check_generate_req(recv_req):
     return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
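The new `_publish_kv_events` helper drains events from the radix cache and forwards them as a timestamped `KVEventBatch` to the publisher created by `EventPublisherFactory` when `--kv-events-config` is set. A minimal sketch of that drain-and-publish pattern using stand-in classes (the real cache and publisher backends are not shown here):

```python
import time
from dataclasses import dataclass
from typing import Any, List


@dataclass
class KVEventBatch:
    # Mirrors the fields used in the hunk above: a timestamp plus the drained events.
    ts: float
    events: List[Any]


class InMemoryPublisher:
    # Stand-in for the publisher object returned by EventPublisherFactory.create(...).
    def __init__(self) -> None:
        self.batches: List[KVEventBatch] = []

    def publish(self, batch: KVEventBatch) -> None:
        self.batches.append(batch)


class FakeTreeCache:
    # Stand-in cache that accumulates events until they are drained.
    def __init__(self) -> None:
        self._events: List[Any] = []

    def record(self, event: Any) -> None:
        self._events.append(event)

    def take_events(self) -> List[Any]:
        events, self._events = self._events, []
        return events


def publish_kv_events(tree_cache: FakeTreeCache, publisher: InMemoryPublisher) -> None:
    # Same shape as Scheduler._publish_kv_events: drain, wrap, publish only if non-empty.
    events = tree_cache.take_events()
    if events:
        publisher.publish(KVEventBatch(ts=time.time(), events=events))
```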
@@ -2225,6 +2288,10 @@ def run_scheduler_process(
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

+    embedding_cache_size = 100
+    if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
+        embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
+    init_embedding_cache(embedding_cache_size * 1024 * 1024)
     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)