sglang 0.4.9.post6__py3-none-any.whl → 0.4.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +3 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +10 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +20 -640
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
- sglang/srt/layers/quantization/fp8.py +0 -18
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/managers/cache_controller.py +143 -45
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +0 -2
- sglang/srt/managers/scheduler.py +89 -671
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +123 -74
- sglang/srt/managers/tp_worker.py +4 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +45 -11
- sglang/srt/mem_cache/hiradix_cache.py +15 -4
- sglang/srt/mem_cache/memory_pool_host.py +73 -1
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/model_runner.py +5 -0
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +2 -0
- sglang/srt/models/glm4_moe.py +3 -1
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/step3_vl.py +994 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +10 -13
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/METADATA +3 -4
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/RECORD +69 -56
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -13,7 +13,6 @@
 # ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""
 
-import datetime
 import faulthandler
 import logging
 import os
@@ -21,11 +20,10 @@ import signal
 import sys
 import threading
 import time
-from collections import
+from collections import deque
 from concurrent import futures
 from dataclasses import dataclass
 from http import HTTPStatus
-from pathlib import Path
 from types import SimpleNamespace
 from typing import Dict, List, Optional, Tuple, Union
 
@@ -37,7 +35,6 @@ from torch.distributed import barrier
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.constrained.base_grammar_backend import (
     INVALID_GRAMMAR_OBJ,
     create_grammar_backend,
@@ -47,7 +44,6 @@ from sglang.srt.disaggregation.decode import (
     DecodeTransferQueue,
     SchedulerDisaggregationDecodeMixin,
 )
-from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.prefill import (
     PrefillBootstrapQueue,
     SchedulerDisaggregationPrefillMixin,
@@ -78,21 +74,15 @@ from sglang.srt.managers.io_struct import (
     GetInternalStateReq,
     GetInternalStateReqOutput,
     GetWeightsByNameReqInput,
-    GetWeightsByNameReqOutput,
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
-    InitWeightsUpdateGroupReqOutput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
-    ProfileReqOutput,
-    ProfileReqType,
     ReleaseMemoryOccupationReqInput,
-    ReleaseMemoryOccupationReqOutput,
     ResumeMemoryOccupationReqInput,
-    ResumeMemoryOccupationReqOutput,
     RpcReqInput,
     RpcReqOutput,
     SetInternalStateReq,
@@ -104,11 +94,8 @@ from sglang.srt.managers.io_struct import (
     UnloadLoRAAdapterReqInput,
     UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
-    UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
-    UpdateWeightsFromDistributedReqOutput,
     UpdateWeightsFromTensorReqInput,
-    UpdateWeightsFromTensorReqOutput,
 )
 from sglang.srt.managers.mm_utils import init_embedding_cache
 from sglang.srt.managers.schedule_batch import (
@@ -124,9 +111,17 @@ from sglang.srt.managers.schedule_policy import (
     SchedulePolicy,
 )
 from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
+from sglang.srt.managers.scheduler_metrics_mixin import (
+    RECORD_STEP_TIME,
+    SchedulerMetricsMixin,
+)
 from sglang.srt.managers.scheduler_output_processor_mixin import (
     SchedulerOutputProcessorMixin,
 )
+from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
+from sglang.srt.managers.scheduler_update_weights_mixin import (
+    SchedulerUpdateWeightsMixin,
+)
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
@@ -135,7 +130,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -168,7 +162,6 @@ logger = logging.getLogger(__name__)
 
 # Test retract decode for debugging purposes
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
-RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
 
 _is_cpu = is_cpu()
@@ -191,41 +184,11 @@ class EmbeddingBatchResult:
     bid: int
 
 
-class KvMetrics:
-    def __init__(self):
-        self.request_active_slots = None
-        self.request_total_slots = None
-        self.kv_active_blocks = None
-        self.kv_total_blocks = None
-        self.num_requests_waiting = None
-        self.gpu_cache_usage_perc = None
-        self.gpu_prefix_cache_hit_rate = None
-        self.data_parallel_rank = None
-
-
-class IdleSleeper:
-    """
-    In setups which have long inactivity periods it is desirable to reduce
-    system power consumption when sglang does nothing. This would lead not only
-    to power savings, but also to more CPU thermal headroom when a request
-    eventually comes. This is important in cases when multiple GPUs are connected
-    as each GPU would otherwise pin one thread at 100% CPU usage.
-
-    The simplest solution is to use zmq.Poller on all sockets that may receive
-    data that needs handling immediately.
-    """
-
-    def __init__(self, sockets):
-        self.poller = zmq.Poller()
-        for s in sockets:
-            self.poller.register(s, zmq.POLLIN)
-
-    def maybe_sleep(self):
-        self.poller.poll(1000)
-
-
 class Scheduler(
     SchedulerOutputProcessorMixin,
+    SchedulerUpdateWeightsMixin,
+    SchedulerProfilerMixin,
+    SchedulerMetricsMixin,
     SchedulerDisaggregationDecodeMixin,
     SchedulerDisaggregationPrefillMixin,
 ):
@@ -237,15 +200,18 @@ class Scheduler(
         port_args: PortArgs,
         gpu_id: int,
         tp_rank: int,
+        moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
     ):
         # Parse args
         self.server_args = server_args
         self.tp_rank = tp_rank
+        self.moe_ep_rank = moe_ep_rank
         self.pp_rank = pp_rank
         self.dp_rank = dp_rank
         self.tp_size = server_args.tp_size
+        self.moe_ep_size = server_args.ep_size
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
@@ -266,7 +232,7 @@ class Scheduler(
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
         self.enable_hicache_storage = server_args.hicache_storage_backend is not None
         self.page_size = server_args.page_size
-
+
         self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
@@ -284,10 +250,13 @@ class Scheduler(
         self.recv_from_tokenizer = get_zmq_socket(
             context, zmq.PULL, port_args.scheduler_input_ipc_name, False
         )
+        self.recv_from_rpc = get_zmq_socket(
+            context, zmq.DEALER, port_args.rpc_ipc_name, False
+        )
+
         self.send_to_tokenizer = get_zmq_socket(
             context, zmq.PUSH, port_args.tokenizer_ipc_name, False
         )
-
         if server_args.skip_tokenizer_init:
             # Directly send to the TokenizerManager
             self.send_to_detokenizer = get_zmq_socket(
@@ -299,9 +268,6 @@ class Scheduler(
                 context, zmq.PUSH, port_args.detokenizer_ipc_name, False
             )
 
-        self.recv_from_rpc = get_zmq_socket(
-            context, zmq.DEALER, port_args.rpc_ipc_name, False
-        )
         if self.server_args.sleep_on_idle:
             self.idle_sleeper = IdleSleeper(
                 [
@@ -347,6 +313,7 @@ class Scheduler(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
+            moe_ep_rank=moe_ep_rank,
             pp_rank=pp_rank,
             dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
@@ -359,6 +326,7 @@ class Scheduler(
             self.draft_worker = EAGLEWorker(
                 gpu_id=gpu_id,
                 tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
                 server_args=server_args,
                 nccl_port=port_args.nccl_port,
                 target_worker=self.tp_worker,
@@ -398,7 +366,7 @@ class Scheduler(
         global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)
 
-        # Hybrid
+        # Hybrid memory pool
         self.is_hybrid = self.tp_worker.is_hybrid
         if self.is_hybrid:
             self.sliding_window_size = self.tp_worker.sliding_window_size
@@ -515,6 +483,15 @@ class Scheduler(
         self.init_metrics(tp_rank, pp_rank, dp_rank)
         self.init_kv_events(server_args.kv_events_config)
 
+        # Init disaggregation
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.init_disaggregation()
+
+        if get_bool_env_var("SGLANG_GC_LOG"):
+            configure_gc_logger()
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -545,22 +522,6 @@ class Scheduler(
             ]
         )
 
-        # Init disaggregation
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.init_disaggregation()
-
-        if get_bool_env_var("SGLANG_GC_LOG"):
-            configure_gc_logger()
-
-    def current_scheduler_metrics_enabled(self):
-        return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
-
-    def maybe_sleep_on_idle(self):
-        if self.idle_sleeper is not None:
-            self.idle_sleeper.maybe_sleep()
-
     def init_tokenizer(self):
         server_args = self.server_args
 
@@ -668,50 +629,6 @@ class Scheduler(
         embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
         init_embedding_cache(embedding_cache_size * 1024 * 1024)
 
-    def init_profier(self):
-        self.torch_profiler = None
-        self.torch_profiler_output_dir: Optional[str] = None
-        self.profiler_activities: Optional[List[str]] = None
-        self.profile_id: Optional[str] = None
-        self.profiler_start_forward_ct: Optional[int] = None
-        self.profiler_target_forward_ct: Optional[int] = None
-        self.profiler_target_prefill_ct: Optional[int] = None
-        self.profiler_target_decode_ct: Optional[int] = None
-        self.profiler_prefill_ct: Optional[int] = None
-        self.profiler_decode_ct: Optional[int] = None
-        self.profile_by_stage: bool = False
-        self.profile_steps: Optional[int] = None
-        self.profile_in_progress: bool = False
-        self.rpd_profiler = None
-
-    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
-        self.last_gen_throughput: float = 0.0
-        self.last_input_throughput: float = 0.0
-        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
-        self.spec_num_total_accepted_tokens = 0
-        self.spec_num_total_forward_ct = 0
-        self.cum_spec_accept_length = 0
-        self.cum_spec_accept_count = 0
-        self.total_retracted_reqs = 0
-        self.stats = SchedulerStats()
-        if self.enable_metrics:
-            engine_type = "unified"
-            labels = {
-                "model_name": self.server_args.served_model_name,
-                "engine_type": engine_type,
-                "tp_rank": tp_rank,
-                "pp_rank": pp_rank,
-            }
-            if dp_rank is not None:
-                labels["dp_rank"] = dp_rank
-            self.metrics_collector = SchedulerMetricsCollector(labels=labels)
-
-    def init_kv_events(self, kv_events_config: Optional[str]):
-        if self.enable_kv_cache_events:
-            self.kv_event_publisher = EventPublisherFactory.create(
-                kv_events_config, self.attn_dp_rank
-            )
-
     def init_disaggregation(self):
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
@@ -820,10 +737,7 @@ class Scheduler(
                 self.process_batch_result(batch, result)
             else:
                 # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.check_tree_cache()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()
 
             self.last_batch = batch
 
@@ -866,10 +780,7 @@ class Scheduler(
                 )
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.check_tree_cache()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()
 
             self.last_batch = batch
 
@@ -1003,10 +914,8 @@ class Scheduler(
 
             # When the server is idle, self-check and re-init some states
             if server_is_idle:
-                self.check_memory()
-                self.check_tree_cache()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                # When the server is idle, do self-check and re-init some states
+                self.self_check_during_idle()
 
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
@@ -1281,23 +1190,28 @@ class Scheduler(
     def _add_request_to_queue(self, req: Req):
         req.queue_time_start = time.perf_counter()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self._prefetch_kvcache(req)
             self.disagg_prefill_bootstrap_queue.add(
                 req, self.model_config.num_key_value_heads
             )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.disagg_decode_prealloc_queue.add(req)
         else:
-            if self.enable_hicache_storage:
-                req.init_next_round_input(self.tree_cache)
-                last_hash = req.last_host_node.get_last_hash_value()
-                matched_len = len(req.prefix_indices) + req.host_hit_length
-                if (matched_len > 0 and last_hash is not None) or matched_len == 0:
-                    new_input_tokens = req.fill_ids[matched_len:]
-                    self.tree_cache.prefetch_from_storage(
-                        req.rid, req.last_host_node, new_input_tokens, last_hash
-                    )
+            self._prefetch_kvcache(req)
             self.waiting_queue.append(req)
 
+    def _prefetch_kvcache(self, req: Req):
+        if self.enable_hicache_storage:
+            req.init_next_round_input(self.tree_cache)
+            last_hash = req.last_host_node.get_last_hash_value()
+            matched_len = len(req.prefix_indices) + req.host_hit_length
+            # todo, free-form fetching, calculating hash keys on the fly
+            if (matched_len > 0 and last_hash is not None) or matched_len == 0:
+                new_input_tokens = req.fill_ids[matched_len:]
+                self.tree_cache.prefetch_from_storage(
+                    req.rid, req.last_host_node, new_input_tokens, last_hash
+                )
+
     def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.disagg_prefill_bootstrap_queue.extend(
@@ -1355,170 +1269,11 @@ class Scheduler(
         req.logprob_start_len = len(req.origin_input_ids) - 1
         self._add_request_to_queue(req)
 
-    def
-
-
-
-            self.stats.token_usage * self.max_total_num_tokens
-        )
-        kv_metrics.kv_total_blocks = self.max_total_num_tokens
-        kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
-        kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
-        kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
-        kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
-
-        if not self.send_metrics_from_scheduler.closed:
-            self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
-
-    def log_prefill_stats(
-        self,
-        adder: PrefillAdder,
-        can_run_list: List[Req],
-        running_bs: int,
-    ):
-        gap_latency = time.perf_counter() - self.last_prefill_stats_tic
-        self.last_prefill_stats_tic = time.perf_counter()
-        self.last_input_throughput = self.last_prefill_tokens / gap_latency
-        self.last_prefill_tokens = adder.log_input_tokens
-
-        if self.is_hybrid:
-            (
-                full_num_used,
-                swa_num_used,
-                full_token_usage,
-                swa_token_usage,
-                _,
-                _,
-                _,
-                _,
-            ) = self._get_swa_token_info()
-            num_used = max(full_num_used, swa_num_used)
-            token_usage = max(full_token_usage, swa_token_usage)
-            token_msg = (
-                f"full token usage: {full_token_usage:.2f}, "
-                f"swa token usage: {swa_token_usage:.2f}, "
-            )
-        else:
-            num_used, token_usage, _, _ = self._get_token_info()
-            token_msg = f"token usage: {token_usage:.2f}, "
-
-        num_new_seq = len(can_run_list)
-        f = (
-            f"Prefill batch. "
-            f"#new-seq: {num_new_seq}, "
-            f"#new-token: {adder.log_input_tokens}, "
-            f"#cached-token: {adder.log_hit_tokens}, "
-            f"{token_msg}"
-        )
-
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
-            f += f"#queue-req: {len(self.waiting_queue)}, "
-            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
-            f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
-        else:
-            f += f"#running-req: {running_bs}, "
-            f += f"#queue-req: {len(self.waiting_queue)}, "
-
-        logger.info(f)
-
-        if self.enable_metrics:
-            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
-
-            cache_hit_rate = (
-                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
-            )
-            self.stats.num_running_reqs = running_bs
-            self.stats.num_used_tokens = num_used
-            self.stats.token_usage = round(token_usage, 2)
-            self.stats.num_queue_reqs = len(self.waiting_queue)
-            self.stats.cache_hit_rate = cache_hit_rate
-
-            total_queue_latency = 0
-            for req in can_run_list:
-                total_queue_latency += req.queue_time_end - req.queue_time_start
-            self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
-
-            self.metrics_collector.log_stats(self.stats)
-        self._emit_kv_metrics()
-        self._publish_kv_events()
-
-    def log_decode_stats(
-        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
-    ):
-        batch = running_batch or self.running_batch
-
-        gap_latency = time.perf_counter() - self.last_decode_stats_tic
-        self.last_decode_stats_tic = time.perf_counter()
-        self.last_gen_throughput = self.num_generated_tokens / gap_latency
-        self.num_generated_tokens = 0
-        num_running_reqs = len(batch.reqs)
-        if self.is_hybrid:
-            (
-                full_num_used,
-                swa_num_used,
-                full_token_usage,
-                swa_token_usage,
-                _,
-                _,
-                _,
-                _,
-            ) = self._get_swa_token_info()
-            num_used = max(full_num_used, swa_num_used)
-            token_usage = max(full_token_usage, swa_token_usage)
-            token_msg = (
-                f"#full token: {full_num_used}, "
-                f"full token usage: {full_token_usage:.2f}, "
-                f"#swa token: {swa_num_used}, "
-                f"swa token usage: {swa_token_usage:.2f}, "
-            )
-        else:
-            num_used, token_usage, _, _ = self._get_token_info()
-            token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
-
-        if RECORD_STEP_TIME:
-            self.step_time_dict[num_running_reqs].append(
-                gap_latency / self.server_args.decode_log_interval
-            )
-
-        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
-
-        if self.spec_algorithm.is_none():
-            spec_accept_length = 0
-        else:
-            spec_accept_length = (
-                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
-            )
-            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
-            self.cum_spec_accept_count += self.spec_num_total_forward_ct
-            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
-            msg += f"accept len: {spec_accept_length:.2f}, "
-
-        if self.disaggregation_mode == DisaggregationMode.DECODE:
-            msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
-            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
-
-        msg += (
-            f"cuda graph: {can_run_cuda_graph}, "
-            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"#queue-req: {len(self.waiting_queue)}, "
-        )
-
-        logger.info(msg)
-        if self.enable_metrics:
-            self.stats.num_running_reqs = num_running_reqs
-            self.stats.num_used_tokens = num_used
-            self.stats.token_usage = round(token_usage, 2)
-            self.stats.cache_hit_rate = 0.0
-            self.stats.gen_throughput = self.last_gen_throughput
-            self.stats.num_queue_reqs = len(self.waiting_queue)
-            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
-            self.stats.spec_accept_length = spec_accept_length
-            self.stats.total_retracted_reqs = self.total_retracted_reqs
-            self.metrics_collector.log_stats(self.stats)
-        self._emit_kv_metrics()
-        self._publish_kv_events()
+    def self_check_during_idle(self):
+        self.check_memory()
+        self.check_tree_cache()
+        self.new_token_ratio = self.init_new_token_ratio
+        self.maybe_sleep_on_idle()
 
     def check_memory(self):
         if self.is_hybrid:
@@ -2422,22 +2177,6 @@ class Scheduler(
             barrier()
         return RpcReqOutput(success, "" if not exec else str(exec))
 
-    def save_remote_model(self, params):
-        url = params["url"]
-
-        worker = self.tp_worker.worker
-
-        worker.model_runner.save_remote_model(url)
-
-    def save_sharded_model(self, params):
-        worker = self.tp_worker.worker
-
-        worker.model_runner.save_sharded_model(
-            path=params["path"],
-            pattern=params["pattern"],
-            max_size=params["max_size"],
-        )
-
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
         to_del = []
@@ -2515,16 +2254,6 @@ class Scheduler(
     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
 
-    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
-        """In-place update of the weights from disk."""
-        success, message = self.tp_worker.update_weights_from_disk(recv_req)
-        if success:
-            flush_cache_success = self.flush_cache()
-            assert flush_cache_success, "Cache flush failed after updating weights"
-        else:
-            logger.error(message)
-        return UpdateWeightFromDiskReqOutput(success, message, 0)
-
     def load_lora_adapter(
         self, recv_req: LoadLoRAAdapterReqInput
     ) -> LoadLoRAAdapterReqOutput:
@@ -2541,81 +2270,6 @@ class Scheduler(
         result = self.tp_worker.unload_lora_adapter(recv_req)
         return result
 
-    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
-        """Initialize the online model parameter update group."""
-        success, message = self.tp_worker.init_weights_update_group(recv_req)
-        return InitWeightsUpdateGroupReqOutput(success, message)
-
-    def update_weights_from_distributed(
-        self,
-        recv_req: UpdateWeightsFromDistributedReqInput,
-    ) -> Tuple[bool, str]:
-        """Update the online model parameter."""
-        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
-        if success:
-            if recv_req.flush_cache:
-                flush_cache_success = self.flush_cache()
-                assert flush_cache_success, "Cache flush failed after updating weights"
-        else:
-            logger.error(message)
-        return UpdateWeightsFromDistributedReqOutput(success, message)
-
-    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
-        """Update the online model parameter from tensors."""
-        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
-        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
-        if success:
-            if recv_req.flush_cache:
-                flush_cache_success = self.flush_cache()
-                assert flush_cache_success, "Cache flush failed after updating weights"
-        else:
-            logger.error(message)
-        barrier(group=self.tp_cpu_group)
-        return UpdateWeightsFromTensorReqOutput(success, message)
-
-    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
-        parameter = self.tp_worker.get_weights_by_name(recv_req)
-        return GetWeightsByNameReqOutput(parameter)
-
-    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
-        tags = recv_req.tags
-
-        if tags is None or len(tags) == 0:
-            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
-
-        if GPU_MEMORY_TYPE_KV_CACHE in tags:
-            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
-            self.flush_cache()
-
-        if GPU_MEMORY_TYPE_WEIGHTS in tags:
-            self.stashed_model_static_state = _export_static_state(
-                self.tp_worker.worker.model_runner.model
-            )
-            torch.distributed.barrier(self.tp_cpu_group)
-            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
-
-        return ReleaseMemoryOccupationReqOutput()
-
-    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
-        tags = recv_req.tags
-
-        if tags is None or len(tags) == 0:
-            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
-
-        if GPU_MEMORY_TYPE_WEIGHTS in tags:
-            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
-            torch.distributed.barrier(self.tp_cpu_group)
-            _import_static_state(
-                self.tp_worker.worker.model_runner.model,
-                self.stashed_model_static_state,
-            )
-            del self.stashed_model_static_state
-
-        if GPU_MEMORY_TYPE_KV_CACHE in tags:
-            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
-
-        return ResumeMemoryOccupationReqOutput()
-
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:
@@ -2623,254 +2277,6 @@ class Scheduler(
             self.forward_sleep_time = t
         return SlowDownReqOutput()
 
-    def profile(self, recv_req: ProfileReq):
-        if recv_req.type == ProfileReqType.START_PROFILE:
-            if recv_req.profile_by_stage or recv_req.start_step:
-                return self.init_profile(
-                    recv_req.output_dir,
-                    recv_req.start_step,
-                    recv_req.num_steps,
-                    recv_req.activities,
-                    recv_req.with_stack,
-                    recv_req.record_shapes,
-                    recv_req.profile_by_stage,
-                    recv_req.profile_id,
-                )
-            else:
-                self.init_profile(
-                    recv_req.output_dir,
-                    recv_req.start_step,
-                    recv_req.num_steps,
-                    recv_req.activities,
-                    recv_req.with_stack,
-                    recv_req.record_shapes,
-                    recv_req.profile_by_stage,
-                    recv_req.profile_id,
-                )
-                return self.start_profile(True)
-        else:
-            return self.stop_profile()
-
-    def init_profile(
-        self,
-        output_dir: Optional[str],
-        start_step: Optional[int],
-        num_steps: Optional[int],
-        activities: Optional[List[str]],
-        with_stack: Optional[bool],
-        record_shapes: Optional[bool],
-        profile_by_stage: bool,
-        profile_id: str,
-    ) -> ProfileReqOutput:
-        if self.profile_in_progress:
-            return ProfileReqOutput(
-                success=False,
-                message="Profiling is already in progress. Call /stop_profile first.",
-            )
-
-        self.profile_by_stage = profile_by_stage
-
-        if output_dir is None:
-            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
-        if activities is None:
-            activities = ["CPU", "GPU"]
-
-        self.torch_profiler_output_dir = output_dir
-        self.torch_profiler_with_stack = with_stack
-        self.torch_profiler_record_shapes = record_shapes
-        self.profiler_activities = activities
-        self.profile_id = profile_id
-
-        if start_step:
-            self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
-
-        if num_steps:
-            self.profile_steps = num_steps
-            if self.profile_by_stage:
-                self.profiler_target_prefill_ct = num_steps
-                self.profiler_target_decode_ct = num_steps
-                self.profiler_prefill_ct = 0
-                self.profiler_decode_ct = 0
-            elif start_step:
-                self.profiler_target_forward_ct = (
-                    self.profiler_start_forward_ct + num_steps
-                )
-            else:
-                self.profiler_target_forward_ct = self.forward_ct + num_steps
-            # The caller will be notified when reaching profiler_target_forward_ct
-        else:
-            self.profiler_target_forward_ct = None
-
-        return ProfileReqOutput(success=True, message="Succeeded")
-
-    def start_profile(
-        self, stage: Optional[ForwardMode] = None
-    ) -> ProfileReqOutput | None:
-        stage_str = f" for {stage.__str__()}" if stage else ""
-        logger.info(
-            f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
-        )
-
-        activities = self.profiler_activities
-        with_stack = self.torch_profiler_with_stack
-        record_shapes = self.torch_profiler_record_shapes
-
-        activity_map = {
-            "CPU": torch.profiler.ProfilerActivity.CPU,
-            "GPU": torch.profiler.ProfilerActivity.CUDA,
-        }
-        torchprof_activities = [
-            activity_map[a] for a in activities if a in activity_map
-        ]
-
-        if "RPD" in activities:
-            from rpdTracerControl import rpdTracerControl
-
-            rpdTracerControl.skipCreate()
-
-            self.rpd_profile_path = os.path.join(
-                self.torch_profiler_output_dir,
-                "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
-            )
-
-            if self.tp_rank == 0:
-                import sqlite3
-
-                from rocpd.schema import RocpdSchema
-
-                if os.path.exists("trace.rpd"):
-                    os.unlink("trace.rpd")
-                schema = RocpdSchema()
-                connection = sqlite3.connect("trace.rpd")
-                schema.writeSchema(connection)
-                connection.commit()
-                del connection
-            torch.distributed.barrier(self.tp_cpu_group)
-
-            self.rpd_profiler = rpdTracerControl()
-            self.rpd_profiler.setPythonTrace(True)
-            self.rpd_profiler.start()
-            self.rpd_profiler.rangePush("", "rpd profile range", "")
-            self.profile_in_progress = True
-        elif torchprof_activities:
-            self.torch_profiler = torch.profiler.profile(
-                activities=torchprof_activities,
-                with_stack=with_stack if with_stack is not None else True,
-                record_shapes=record_shapes if record_shapes is not None else False,
-            )
-            self.torch_profiler.start()
-            self.profile_in_progress = True
-
-        if "MEM" in activities:
-            torch.cuda.memory._record_memory_history(max_entries=100000)
-            self.profile_in_progress = True
-
-        if "CUDA_PROFILER" in activities:
-            torch.cuda.cudart().cudaProfilerStart()
-            self.profile_in_progress = True
-
-        return ProfileReqOutput(success=True, message="Succeeded")
-
-    def stop_profile(
-        self, stage: Optional[ForwardMode] = None
-    ) -> ProfileReqOutput | None:
-        if not self.profile_in_progress:
-            return ProfileReqOutput(
-                success=False,
-                message="Profiling is not in progress. Call /start_profile first.",
-            )
-
-        if not Path(self.torch_profiler_output_dir).exists():
-            Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
-
-        stage_suffix = f"-{stage.__str__()}" if stage else ""
-        logger.info("Stop profiling" + stage_suffix + "...")
-        if self.torch_profiler is not None:
-            self.torch_profiler.stop()
-            self.torch_profiler.export_chrome_trace(
-                os.path.join(
-                    self.torch_profiler_output_dir,
-                    self.profile_id
-                    + f"-TP-{self.tp_rank}"
-                    + stage_suffix
-                    + ".trace.json.gz",
-                )
-            )
-            torch.distributed.barrier(self.tp_cpu_group)
-
-        if self.rpd_profiler is not None:
-            self.rpd_profiler.rangePop()
-            self.rpd_profiler.stop()
-            self.rpd_profiler.flush()
-
-            torch.distributed.barrier(self.tp_cpu_group)
-            if self.tp_rank == 0:
-                from sglang.srt.utils import rpd_to_chrome_trace
-
-                rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
-            self.rpd_profiler = None
-            self.rpd_profiler_path = None
-
-        if self.profiler_activities is not None and "MEM" in self.profiler_activities:
-            memory_profile_path = os.path.join(
-                self.torch_profiler_output_dir,
-                str(time.time())
-                + f"-TP-{self.tp_rank}-memory"
-                + stage_suffix
-                + ".pickle",
-            )
-            torch.cuda.memory._dump_snapshot(memory_profile_path)
-            torch.cuda.memory._record_memory_history(enabled=None)
-
-        if "CUDA_PROFILER" in self.profiler_activities:
-            torch.cuda.cudart().cudaProfilerStop()
-
-        logger.info(
-            "Profiling done. Traces are saved to: %s",
-            self.torch_profiler_output_dir,
-        )
-        self.torch_profiler = None
-        self.profile_in_progress = False
-        self.profiler_start_forward_ct = None
-
-        return ProfileReqOutput(success=True, message="Succeeded.")
-
-    def _profile_batch_predicate(self, batch):
-        if self.profile_by_stage:
-            if batch.forward_mode.is_prefill():
-                if self.profiler_prefill_ct == 0:
-                    self.start_profile(batch.forward_mode)
-                self.profiler_prefill_ct += 1
-                if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
-                    if self.profile_in_progress:
-                        self.stop_profile(stage=ForwardMode.EXTEND)
-            elif batch.forward_mode.is_decode():
-                if self.profiler_decode_ct == 0:
-                    if self.profile_in_progress:
-                        # force trace flush
-                        self.stop_profile(ForwardMode.EXTEND)
-                    self.start_profile(batch.forward_mode)
-                self.profiler_decode_ct += 1
-                if self.profiler_decode_ct > self.profiler_target_decode_ct:
-                    if self.profile_in_progress:
-                        self.stop_profile(stage=ForwardMode.DECODE)
-            elif batch.forward_mode.is_idle():
-                pass
-            else:
-                raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
-        else:
-            # Check profiler
-            if (
-                self.profiler_target_forward_ct
-                and self.profiler_target_forward_ct <= self.forward_ct
-            ):
-                self.stop_profile()
-            if (
-                self.profiler_start_forward_ct
-                and self.profiler_start_forward_ct == self.forward_ct
-            ):
-                self.start_profile()
-
     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
         if recv_req == ExpertDistributionReq.START_RECORD:
             get_global_expert_distribution_recorder().start_record()
@@ -2879,7 +2285,7 @@ class Scheduler(
         elif recv_req == ExpertDistributionReq.DUMP_RECORD:
             get_global_expert_distribution_recorder().dump_record()
         else:
-            raise ValueError("Unrecognized ExpertDistributionReq value")
+            raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
         return ExpertDistributionReqOutput()
 
     def open_session(self, recv_req: OpenSessionReqInput):
@@ -2915,34 +2321,41 @@ class Scheduler(
             prefix += f" PP{self.pp_rank}"
         return prefix
 
-    def
-
-        events = self.tree_cache.take_events()
-        if events:
-            batch = KVEventBatch(ts=time.time(), events=events)
-            self.kv_event_publisher.publish(batch)
+    def current_scheduler_metrics_enabled(self):
+        return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
 
+    def maybe_sleep_on_idle(self):
+        if self.idle_sleeper is not None:
+            self.idle_sleeper.maybe_sleep()
 
-def is_health_check_generate_req(recv_req):
-    return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
 
+class IdleSleeper:
+    """
+    In setups which have long inactivity periods it is desirable to reduce
+    system power consumption when sglang does nothing. This would lead not only
+    to power savings, but also to more CPU thermal headroom when a request
+    eventually comes. This is important in cases when multiple GPUs are connected
+    as each GPU would otherwise pin one thread at 100% CPU usage.
 
-
-
+    The simplest solution is to use zmq.Poller on all sockets that may receive
+    data that needs handling immediately.
+    """
 
+    def __init__(self, sockets):
+        self.poller = zmq.Poller()
+        for s in sockets:
+            self.poller.register(s, zmq.POLLIN)
 
-    def
-
-        buffers=[
-            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
-        ]
-    )
+    def maybe_sleep(self):
+        self.poller.poll(1000)
 
 
-def
-
-
-
+def is_health_check_generate_req(recv_req):
+    return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
+
+
+def is_work_request(recv_req):
+    return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
 
 
 def run_scheduler_process(
@@ -2950,6 +2363,7 @@ def run_scheduler_process(
     port_args: PortArgs,
     gpu_id: int,
     tp_rank: int,
+    moe_ep_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
@@ -2960,6 +2374,8 @@ def run_scheduler_process(
         prefix += f" DP{dp_rank}"
     if server_args.tp_size > 1:
         prefix += f" TP{tp_rank}"
+    if server_args.ep_size > 1:
+        prefix += f" EP{moe_ep_rank}"
     if server_args.pp_size > 1:
         prefix += f" PP{pp_rank}"
 
@@ -2983,7 +2399,9 @@
 
     # Create a scheduler and run the event loop
     try:
-        scheduler = Scheduler(
+        scheduler = Scheduler(
+            server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank
+        )
         pipe_writer.send(
             {
                 "status": "ready",