sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -41,14 +41,17 @@ from sglang.srt.disaggregation.decode import (
     DecodeTransferQueue,
     SchedulerDisaggregationDecodeMixin,
 )
+from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.prefill import (
     PrefillBootstrapQueue,
     SchedulerDisaggregationPrefillMixin,
 )
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
+    prepare_abort,
 )
 from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.hf_transformers_utils import (
@@ -58,7 +61,10 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -97,6 +103,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
+from sglang.srt.managers.mm_utils import init_embedding_cache
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     MultimodalInputs,
@@ -129,7 +136,6 @@ from sglang.srt.utils import (
     DynamicGradMode,
     broadcast_pyobj,
     configure_logger,
-    crash_on_warnings,
     disable_request_logging,
     get_bool_env_var,
     get_zmq_socket,
@@ -142,8 +148,6 @@ from sglang.srt.utils import (
 )
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback

-expert_distribution_recorder = ExpertDistributionRecorder()
-
 logger = logging.getLogger(__name__)

 # Test retract decode for debugging purposes
@@ -198,6 +202,7 @@ class Scheduler(
         self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
+        self.enable_kv_cache_events = server_args.kv_events_config is not None
         self.stream_interval = server_args.stream_interval
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
@@ -205,7 +210,6 @@ class Scheduler(
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
         self.page_size = server_args.page_size
-
         # Distributed rank info
         self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
@@ -349,8 +353,8 @@ class Scheduler(
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
         self.num_prefill_tokens = 0
-        self.last_decode_stats_tic = time.time()
-        self.last_prefill_stats_tic = time.time()
+        self.last_decode_stats_tic = time.perf_counter()
+        self.last_prefill_stats_tic = time.perf_counter()
         self.return_health_check_ct = 0
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
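The timing fields above move from `time.time()` to `time.perf_counter()`. A small standalone illustration (not sglang code) of the difference: `perf_counter` is a monotonic, high-resolution clock suited to measuring elapsed intervals, while `time.time()` is wall-clock time and can jump when the system clock is adjusted.

```python
import time

# Measure an interval the way the scheduler now does: perf_counter is monotonic,
# so the computed gap can never go negative even if the wall clock is adjusted.
start = time.perf_counter()
time.sleep(0.1)  # stand-in for a forward batch
gap_latency = time.perf_counter() - start
print(f"elapsed: {gap_latency:.4f}s")

# time.time() remains the right choice for timestamps that must be meaningful
# outside the process (as in KVEventBatch(ts=time.time())), but not for durations.
print(f"wall-clock timestamp: {time.time():.0f}")
```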
@@ -423,6 +427,7 @@ class Scheduler(

         # Init metrics stats
         self.init_metrics()
+        self.init_kv_events(server_args.kv_events_config)

         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
@@ -516,6 +521,7 @@ class Scheduler(
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
+                enable_kv_cache_events=self.enable_kv_cache_events,
             )

         self.decode_mem_cache_buf_multiplier = (
@@ -548,6 +554,10 @@ class Scheduler(
             },
         )

+    def init_kv_events(self, kv_events_config: Optional[str]):
+        if self.enable_kv_cache_events:
+            self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)
+
     def init_disaggregation(self):
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
@@ -560,29 +570,28 @@ class Scheduler(
             req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-
-            # A list of metadata buffers. The shape is (b, metadata_size) where
-            # b corresponds to a max running requests. The last shape * dtype.itemsize
-            # should be larger than 64 bytes to work with RDMA, so we pad it.
-            output_id_buffer = torch.zeros(
-                (buffer_size, 16), dtype=aux_dtype, device="cpu"
-            )
-            metadata_buffers = [output_id_buffer]
+            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
                 gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
+                metadata_buffers=self.disagg_metadata_buffers,
+                scheduler=self,
+                tree_cache=self.tree_cache,
             )

             # The decode requests pending for pre-allocation
             self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                draft_token_to_kv_pool=(
+                    None
+                    if self.draft_worker is None
+                    else self.draft_worker.model_runner.token_to_kv_pool
+                ),
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
-                aux_dtype=aux_dtype,
+                metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 transfer_queue=self.disagg_decode_transfer_queue,
                 tree_cache=self.tree_cache,
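The removed comment explains the old layout: each per-request metadata row had to exceed 64 bytes so RDMA transfers work, hence the `(buffer_size, 16)` shape. The new `MetadataBuffers` helper in `sglang.srt.disaggregation.utils` now owns that detail. The class below is only a hypothetical sketch of the same padding rule; its name and structure are assumptions, not the actual sglang implementation.

```python
import torch

class PaddedMetadataBuffers:
    """Hypothetical sketch: per-request metadata rows padded to >= 64 bytes for RDMA."""

    RDMA_MIN_BYTES = 64

    def __init__(self, buffer_size: int, dtype: torch.dtype = torch.int32):
        item_size = torch.empty((), dtype=dtype).element_size()
        # Pad the row so that row_elems * item_size is at least 64 bytes (ceil division).
        row_elems = -(-self.RDMA_MIN_BYTES // item_size)
        self.output_ids = torch.zeros((buffer_size, row_elems), dtype=dtype, device="cpu")

    def row_nbytes(self) -> int:
        return self.output_ids.shape[1] * self.output_ids.element_size()

bufs = PaddedMetadataBuffers(buffer_size=8)
# 16 int32 elements per row -> 64 bytes, matching the old (buffer_size, 16) buffer.
assert bufs.row_nbytes() >= 64
```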
@@ -602,20 +611,17 @@ class Scheduler(
             req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-
-            # A list of metadata buffers. The shape is (b, metadata_size) where
-            # b corresponds to a max running requests. The last shape * dtype.itemsize
-            # should be larger than 64 bytes to work with RDMA, so we pad it.
-            output_id_buffer = torch.zeros(
-                (buffer_size, 16), dtype=aux_dtype, device="cpu"
-            )
-            metadata_buffers = [output_id_buffer]
+            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

             self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
+                draft_token_to_kv_pool=(
+                    None
+                    if self.draft_worker is None
+                    else self.draft_worker.model_runner.token_to_kv_pool
+                ),
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
-                aux_dtype=aux_dtype,
+                metadata_buffers=self.disagg_metadata_buffers,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
@@ -928,6 +934,18 @@ class Scheduler(
             )
             req.tokenizer = self.tokenizer

+        if self.disaggregation_mode != DisaggregationMode.NULL:
+            # Invalid request for disaggregated mode
+            if recv_req.bootstrap_room is None:
+                error_message = (
+                    f"Invalid request: Disaggregated request received without "
+                    f"boostrap room id. {req.rid=}"
+                )
+                logger.error(error_message)
+                prepare_abort(req, error_message)
+                self.stream_output([req], req.return_logprob)
+                return
+
         if (
             recv_req.session_params is not None
             and recv_req.session_params.id is not None
@@ -1033,13 +1051,13 @@ class Scheduler(
                 add_to_grammar_queue = True

         if add_to_grammar_queue:
-            req.queue_time_start = time.time()
+            req.queue_time_start = time.perf_counter()
             self.grammar_queue.append(req)
         else:
             self._add_request_to_queue(req)

     def _add_request_to_queue(self, req: Req):
-        req.queue_time_start = time.time()
+        req.queue_time_start = time.perf_counter()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.disagg_prefill_bootstrap_queue.add(req)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
@@ -1047,8 +1065,11 @@ class Scheduler(
         else:
             self.waiting_queue.append(req)

-    def _extend_requests_to_queue(self, reqs: List[Req]):
-        if self.disaggregation_mode == DisaggregationMode.DECODE:
+    def _extend_requests_to_queue(self, reqs: List[Req]):
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.disagg_prefill_bootstrap_queue.extend(reqs)
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            # If this is a decode server, we put the request to the decode pending prealloc queue
             self.disagg_decode_prealloc_queue.extend(reqs)
         else:
             self.waiting_queue.extend(reqs)
@@ -1086,7 +1107,7 @@ class Scheduler(
             req.finished_reason = FINISH_ABORT(
                 error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
             )
-            req.queue_time_start = time.time()
+            req.queue_time_start = time.perf_counter()
             self.waiting_queue.append(req)
             return

@@ -1110,8 +1131,8 @@ class Scheduler(
         can_run_list: List[Req],
         running_bs: int,
     ):
-        gap_latency = time.time() - self.last_prefill_stats_tic
-        self.last_prefill_stats_tic = time.time()
+        gap_latency = time.perf_counter() - self.last_prefill_stats_tic
+        self.last_prefill_stats_tic = time.perf_counter()
         self.last_input_throughput = self.num_prefill_tokens / gap_latency
         self.num_prefill_tokens = 0

@@ -1155,14 +1176,15 @@ class Scheduler(
             self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq

             self.metrics_collector.log_stats(self.stats)
+            self._publish_kv_events()

     def log_decode_stats(
         self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
     ):
         batch = running_batch or self.running_batch

-        gap_latency = time.time() - self.last_decode_stats_tic
-        self.last_decode_stats_tic = time.time()
+        gap_latency = time.perf_counter() - self.last_decode_stats_tic
+        self.last_decode_stats_tic = time.perf_counter()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
         num_running_reqs = len(batch.reqs)
@@ -1214,6 +1236,7 @@ class Scheduler(
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.stats.spec_accept_length = spec_accept_length
             self.metrics_collector.log_stats(self.stats)
+            self._publish_kv_events()

     def check_memory(self):
         available_size = (
@@ -1246,7 +1269,7 @@ class Scheduler(
         if (
             self.enable_metrics
             and self.attn_tp_rank == 0
-            and time.time() > self.metrics_collector.last_log_time + 30
+            and time.perf_counter() > self.metrics_collector.last_log_time + 30
         ):
             # During idle time, also collect metrics every 30 seconds.
             num_used = self.max_total_num_tokens - (
@@ -1261,6 +1284,7 @@ class Scheduler(
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.metrics_collector.log_stats(self.stats)
+            self._publish_kv_events()

     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
@@ -1383,6 +1407,13 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break

+            if self.disaggregation_mode == DisaggregationMode.PREFILL:
+                # In prefill mode, prealloc queue and transfer queue can also take memory,
+                # so we need to check if the available size for the actual available size.
+                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
+                    self.running_batch.batch_is_full = True
+                    break
+
             req.init_next_round_input(
                 None if prefix_computed else self.tree_cache,
                 self.enable_hierarchical_cache,
@@ -1411,7 +1442,7 @@ class Scheduler(
         if self.enable_metrics:
             # only record queue time when enable_metrics is True to avoid overhead
             for req in can_run_list:
                req.queue_time_end = time.perf_counter()

         self.waiting_queue = [
             x for x in self.waiting_queue if x not in set(can_run_list)
@@ -1513,7 +1544,7 @@ class Scheduler(
             self.profiler_target_forward_ct
             and self.profiler_target_forward_ct <= self.forward_ct
         ):
-            self.stop_profile()
+            self.send_to_tokenizer.send_pyobj(self.stop_profile())

         if self.forward_sleep_time is not None:
             logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
@@ -1784,10 +1815,10 @@ class Scheduler(
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
         self.watchdog_last_forward_ct = 0
-        self.watchdog_last_time = time.time()
+        self.watchdog_last_time = time.perf_counter()

         while True:
-            current = time.time()
+            current = time.perf_counter()
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
                     if current > self.watchdog_last_time + self.watchdog_timeout:
@@ -2115,7 +2146,10 @@ class Scheduler(

     def stop_profile(self) -> None:
         if self.profiler_activities is None:
-            return
+            return ProfileReqOutput(
+                success=False,
+                message="Profiling is not in progress. Call /start_profile first.",
+            )

         logger.info("Stop profiling...")
         if self.torch_profiler is not None:
@@ -2146,18 +2180,15 @@ class Scheduler(
         self.torch_profiler_output_dir = None
         self.profiler_activities = None

-
-        self.send_to_tokenizer.send_pyobj(
-            ProfileReqOutput(success=True, message="Succeeded.")
-        )
+        return ProfileReqOutput(success=True, message="Succeeded")

     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
         if recv_req == ExpertDistributionReq.START_RECORD:
-            expert_distribution_recorder.start_record()
+            get_global_expert_distribution_recorder().start_record()
         elif recv_req == ExpertDistributionReq.STOP_RECORD:
-            expert_distribution_recorder.stop_record()
+            get_global_expert_distribution_recorder().stop_record()
         elif recv_req == ExpertDistributionReq.DUMP_RECORD:
-            expert_distribution_recorder.dump_record()
+            get_global_expert_distribution_recorder().dump_record()
         else:
             raise ValueError("Unrecognized ExpertDistributionReq value")
         return ExpertDistributionReqOutput()
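Call sites now obtain the recorder through `get_global_expert_distribution_recorder()` instead of the removed module-level `expert_distribution_recorder` instance. Below is a minimal sketch of that accessor pattern with stand-in types; the real accessor in `sglang.srt.managers.expert_distribution` may create or register the instance differently.

```python
from typing import Optional

class ExpertDistributionRecorder:
    """Stand-in recorder exposing the three operations the scheduler dispatches to."""

    def start_record(self) -> None: ...
    def stop_record(self) -> None: ...
    def dump_record(self) -> None: ...

_global_expert_distribution_recorder: Optional[ExpertDistributionRecorder] = None

def get_global_expert_distribution_recorder() -> ExpertDistributionRecorder:
    # Created lazily in this sketch, so importing the module has no side effects,
    # unlike the old module-level `expert_distribution_recorder = ExpertDistributionRecorder()`.
    global _global_expert_distribution_recorder
    if _global_expert_distribution_recorder is None:
        _global_expert_distribution_recorder = ExpertDistributionRecorder()
    return _global_expert_distribution_recorder

get_global_expert_distribution_recorder().start_record()
```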
@@ -2195,6 +2226,13 @@ class Scheduler(
             prefix += f" PP{self.pp_rank}"
         return prefix

+    def _publish_kv_events(self):
+        if self.enable_kv_cache_events:
+            events = self.tree_cache.take_events()
+            if events:
+                batch = KVEventBatch(ts=time.time(), events=events)
+                self.kv_event_publisher.publish(batch)
+

 def is_health_check_generate_req(recv_req):
     return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
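The new `_publish_kv_events` hook drains events from the radix cache and publishes them as timestamped batches alongside the metrics logs. A self-contained sketch of that drain-and-publish pattern follows; the cache, publisher, and event types here are stand-ins, not the sglang classes.

```python
import time
from dataclasses import dataclass, field
from typing import Any, List

@dataclass
class KVEventBatch:  # stand-in for sglang.srt.disaggregation.kv_events.KVEventBatch
    ts: float
    events: List[Any]

class InMemoryPublisher:  # stand-in for the publisher created by EventPublisherFactory
    def __init__(self) -> None:
        self.batches: List[KVEventBatch] = []

    def publish(self, batch: KVEventBatch) -> None:
        self.batches.append(batch)

@dataclass
class CacheWithEvents:  # stand-in for the radix cache's take_events() behavior
    pending: List[Any] = field(default_factory=list)

    def take_events(self) -> List[Any]:
        # Hand back everything accumulated so far and reset the pending list.
        events, self.pending = self.pending, []
        return events

cache = CacheWithEvents(pending=["block_stored", "block_removed"])
publisher = InMemoryPublisher()

events = cache.take_events()
if events:  # only publish when there is something to report, as the scheduler does
    publisher.publish(KVEventBatch(ts=time.time(), events=events))
```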
@@ -2250,6 +2288,10 @@ def run_scheduler_process(
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

+    embedding_cache_size = 100
+    if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
+        embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
+    init_embedding_cache(embedding_cache_size * 1024 * 1024)
     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
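The multimodal embedding cache defaults to 100 MB and can be overridden with the `SGLANG_VLM_CACHE_SIZE_MB` environment variable; the value is converted to bytes before `init_embedding_cache` is called. A standalone illustration of the same override-and-convert logic (the helper name below is made up for the example):

```python
import os

def resolve_embedding_cache_bytes(default_mb: int = 100) -> int:
    # Environment variables are strings, so the override is parsed with int();
    # a non-numeric value raises ValueError at startup rather than being silently ignored.
    cache_mb = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", default_mb))
    return cache_mb * 1024 * 1024  # MB -> bytes

os.environ["SGLANG_VLM_CACHE_SIZE_MB"] = "256"
assert resolve_embedding_cache_bytes() == 256 * 1024 * 1024
```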
sglang/srt/managers/session_controller.py
CHANGED
@@ -54,7 +54,7 @@ class SessionReqNode:
             prefix += " -- " + self.childs[0].req.rid
             ret = self.childs[0]._str_helper(prefix)
             for child in self.childs[1:]:
-                prefix = " " * len(origin_prefix) + " \- " + child.req.rid
+                prefix = " " * len(origin_prefix) + r" \- " + child.req.rid
                 ret += child._str_helper(prefix)
             return ret

sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -16,6 +16,7 @@
 import asyncio
 import copy
 import dataclasses
+import json
 import logging
 import os
 import pickle
@@ -90,6 +91,8 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
     SessionParams,
+    SetInternalStateReq,
+    SetInternalStateReqOutput,
     SlowDownReqInput,
     SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
@@ -169,6 +172,11 @@ class TokenizerManager:
         self.enable_metrics = server_args.enable_metrics
         self.log_requests = server_args.log_requests
         self.log_requests_level = server_args.log_requests_level
+        self.preferred_sampling_params = (
+            json.loads(server_args.preferred_sampling_params)
+            if server_args.preferred_sampling_params
+            else None
+        )

         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -228,6 +236,7 @@ class TokenizerManager:
         # Store states
         self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.health_check_failed = False
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
         self.dump_requests_folder = ""  # By default do not dump
@@ -255,6 +264,10 @@ class TokenizerManager:
                 "model_name": self.server_args.served_model_name,
                 # TODO: Add lora name/path in the future,
             },
+            bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
+            bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
+            bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
+            collect_tokens_histogram=self.server_args.collect_tokens_histogram,
         )

         # Communicators
@@ -282,12 +295,16 @@ class TokenizerManager:
         self.flush_cache_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        self.start_profile_communicator = _Communicator(
+        self.profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.set_internal_state_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.expert_distribution_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -343,12 +360,16 @@ class TokenizerManager:
             ),
             (
                 ProfileReqOutput,
-                self.start_profile_communicator.handle_recv,
+                self.profile_communicator.handle_recv,
             ),
             (
                 GetInternalStateReqOutput,
                 self.get_internal_state_communicator.handle_recv,
             ),
+            (
+                SetInternalStateReqOutput,
+                self.set_internal_state_communicator.handle_recv,
+            ),
             (
                 ExpertDistributionReqOutput,
                 self.expert_distribution_communicator.handle_recv,
@@ -438,14 +459,16 @@ class TokenizerManager:
         )
         input_ids = self.tokenizer.encode(input_text)

-        image_inputs: Dict = await self.mm_processor.process_mm_data_async(
-            image_data=obj.image_data,
-            input_text=input_text or input_ids,
-            request_obj=obj,
-            max_req_input_len=self.max_req_input_len,
-        )
-        if image_inputs and "input_ids" in image_inputs:
-            input_ids = image_inputs["input_ids"]
+        image_inputs: Optional[Dict] = None
+        if obj.contains_mm_input():
+            image_inputs = await self.mm_processor.process_mm_data_async(
+                image_data=obj.image_data,
+                input_text=input_text or input_ids,
+                request_obj=obj,
+                max_req_input_len=self.max_req_input_len,
+            )
+            if image_inputs and "input_ids" in image_inputs:
+                input_ids = image_inputs["input_ids"]

         self._validate_token_len(obj, input_ids)
         return self._create_tokenized_object(
@@ -508,7 +531,14 @@ class TokenizerManager:
                 "Please set `--enable-custom-logits-processor` to enable this feature."
             )

-        sampling_params = SamplingParams(**obj.sampling_params)
+        # Parse sampling parameters
+        # Note: if there are preferred sampling params, we use them if they are not
+        # explicitly passed in sampling_params
+        if self.preferred_sampling_params:
+            sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params}
+        else:
+            sampling_kwargs = obj.sampling_params
+        sampling_params = SamplingParams(**sampling_kwargs)
         sampling_params.normalize(self.tokenizer)
         sampling_params.verify()

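`preferred_sampling_params` is merged under the per-request `sampling_params` with a dict unpacking expression, so any key the client sets explicitly wins over the server-side default. A quick illustration of that precedence:

```python
# Server-side defaults (the JSON that server_args.preferred_sampling_params is parsed
# from) and a per-request override; the values are illustrative.
preferred_sampling_params = {"temperature": 0.2, "top_p": 0.9, "max_new_tokens": 256}
request_sampling_params = {"temperature": 0.8}

# Later keys win in a dict literal, so the request's temperature overrides the
# preferred default while top_p and max_new_tokens fall back to the server values.
sampling_kwargs = {**preferred_sampling_params, **request_sampling_params}
assert sampling_kwargs == {"temperature": 0.8, "top_p": 0.9, "max_new_tokens": 256}
```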
@@ -667,7 +697,6 @@ class TokenizerManager:

         generators = []
         rids = []
-
         if getattr(obj, "parallel_sample_num", 1) == 1:
             if self.server_args.enable_tokenizer_batch_encode:
                 # Validate batch tokenization constraints
@@ -765,6 +794,7 @@ class TokenizerManager:
         with_stack: Optional[bool] = None,
         record_shapes: Optional[bool] = None,
     ):
+        self.auto_create_handle_loop()
         req = ProfileReq(
             type=ProfileReqType.START_PROFILE,
             output_dir=output_dir,
@@ -774,22 +804,29 @@ class TokenizerManager:
             record_shapes=record_shapes,
             profile_id=str(time.time()),
         )
-        result = (await self.start_profile_communicator(req))[0]
+        return await self._execute_profile(req)
+
+    async def stop_profile(self):
+        self.auto_create_handle_loop()
+        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
+        return await self._execute_profile(req)
+
+    async def _execute_profile(self, req: ProfileReq):
+        result = (await self.profile_communicator(req))[0]
         if not result.success:
             raise RuntimeError(result.message)
         return result

-    def stop_profile(self):
-        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
-        self.send_to_scheduler.send_pyobj(req)
-
     async def start_expert_distribution_record(self):
+        self.auto_create_handle_loop()
         await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)

     async def stop_expert_distribution_record(self):
+        self.auto_create_handle_loop()
         await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)

     async def dump_expert_distribution_record(self):
+        self.auto_create_handle_loop()
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

     async def update_weights_from_disk(
@@ -856,8 +893,8 @@ class TokenizerManager:
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()
         assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be for update weights from distributed"
+            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
+        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"

         # This means that weight sync
         # cannot run while requests are in progress.
@@ -872,8 +909,8 @@ class TokenizerManager:
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()
         assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for update weights from tensor"
+            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
+        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"

         # This means that weight sync
         # cannot run while requests are in progress.
@@ -946,6 +983,14 @@ class TokenizerManager:
             # Many DP ranks
             return [res.internal_state for res in responses]

+    async def set_internal_state(
+        self, obj: SetInternalStateReq
+    ) -> SetInternalStateReqOutput:
+        responses: List[SetInternalStateReqOutput] = (
+            await self.set_internal_state_communicator(obj)
+        )
+        return [res.internal_state for res in responses]
+
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None
@@ -1015,11 +1060,17 @@ class TokenizerManager:
                 loop.create_task(print_exception_wrapper(self.handle_loop))
             )

+        self.event_loop = loop
+
         # We cannot add signal handler when the tokenizer manager is not in
         # the main thread due to the CPython limitation.
         if threading.current_thread() is threading.main_thread():
             signal_handler = SignalHandler(self)
-            loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
+            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
+            # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
+            loop.add_signal_handler(
+                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
+            )
         else:
             logger.warning(
                 "Signal handler is not added because the tokenizer manager is "
@@ -1037,6 +1088,15 @@ class TokenizerManager:
         # Drain requests
         while True:
             remain_num_req = len(self.rid_to_state)
+
+            if self.health_check_failed:
+                # if health check failed, we should exit immediately
+                logger.error(
+                    "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
+                    remain_num_req,
+                )
+                break
+
             logger.info(
                 f"Gracefully exiting... remaining number of requests {remain_num_req}"
             )
@@ -1120,7 +1180,16 @@ class TokenizerManager:
                     "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchMultimodalOut):
-                raise NotImplementedError()
+                if isinstance(recv_obj.outputs[i], str):
+                    out_dict = {
+                        "text": recv_obj.outputs[i],
+                        "meta_info": meta_info,
+                    }
+                else:
+                    out_dict = {
+                        "outputs": json.dumps(recv_obj.outputs[i]),
+                        "meta_info": meta_info,
+                    }
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {
@@ -1366,12 +1435,18 @@ class SignalHandler:
     def __init__(self, tokenizer_manager: TokenizerManager):
         self.tokenizer_manager = tokenizer_manager

-    def signal_handler(self, signum=None, frame=None):
+    def sigterm_handler(self, signum=None, frame=None):
         logger.warning(
             f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
         )
         self.tokenizer_manager.gracefully_exit = True

+    def running_phase_sigquit_handler(self, signum=None, frame=None):
+        logger.error(
+            "Received sigquit from a child process. It usually means the child failed."
+        )
+        kill_process_tree(os.getpid())
+

 T = TypeVar("T")

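The tokenizer manager now registers two loop-level handlers: SIGTERM starts a graceful drain, and SIGQUIT from a failed child kills the process tree. Below is a minimal standalone sketch of wiring such handlers into an asyncio loop (POSIX-only; the shutdown actions are simplified stand-ins for the real ones).

```python
import asyncio
import signal

class DemoSignalHandler:
    def __init__(self) -> None:
        self.gracefully_exit = False

    def sigterm_handler(self, signum=None, frame=None) -> None:
        # Mirror of the graceful path: flip a flag and let the main loop drain requests.
        self.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None) -> None:
        # Stand-in for kill_process_tree(os.getpid()): fail fast when a child reports failure.
        raise SystemExit("child process failure reported via SIGQUIT")

async def main() -> None:
    loop = asyncio.get_running_loop()
    handler = DemoSignalHandler()
    # add_signal_handler only works when the loop runs in the main thread,
    # which is why the real code falls back to a warning otherwise.
    loop.add_signal_handler(signal.SIGTERM, handler.sigterm_handler)
    loop.add_signal_handler(signal.SIGQUIT, handler.running_phase_sigquit_handler)
    await asyncio.sleep(0.1)  # placeholder for the serving loop

if __name__ == "__main__":
    asyncio.run(main())
```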