sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +23 -3
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +98 -603
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +6 -2
- sglang/srt/layers/quantization/fp8.py +0 -18
- sglang/srt/layers/quantization/modelopt_quant.py +2 -0
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/managers/cache_controller.py +143 -45
- sglang/srt/managers/data_parallel_controller.py +6 -0
- sglang/srt/managers/io_struct.py +12 -2
- sglang/srt/managers/scheduler.py +116 -669
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +166 -83
- sglang/srt/managers/tp_worker.py +9 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +45 -11
- sglang/srt/mem_cache/hiradix_cache.py +15 -4
- sglang/srt/mem_cache/memory_pool_host.py +73 -1
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/model_runner.py +20 -13
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +15 -56
- sglang/srt/models/glm4_moe.py +3 -1
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/qwen3_moe.py +12 -69
- sglang/srt/models/step3_vl.py +994 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +18 -13
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/test/test_utils.py +53 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -13,7 +13,6 @@
 # ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""

-import datetime
 import faulthandler
 import logging
 import os
@@ -21,10 +20,10 @@ import signal
 import sys
 import threading
 import time
-from collections import defaultdict, deque
+from collections import deque
 from concurrent import futures
 from dataclasses import dataclass
-from pathlib import Path
+from http import HTTPStatus
 from types import SimpleNamespace
 from typing import Dict, List, Optional, Tuple, Union

@@ -36,7 +35,6 @@ from torch.distributed import barrier

 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.constrained.base_grammar_backend import (
     INVALID_GRAMMAR_OBJ,
     create_grammar_backend,
@@ -46,7 +44,6 @@ from sglang.srt.disaggregation.decode import (
     DecodeTransferQueue,
     SchedulerDisaggregationDecodeMixin,
 )
-from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.prefill import (
     PrefillBootstrapQueue,
     SchedulerDisaggregationPrefillMixin,
@@ -77,21 +74,15 @@ from sglang.srt.managers.io_struct import (
     GetInternalStateReq,
     GetInternalStateReqOutput,
     GetWeightsByNameReqInput,
-    GetWeightsByNameReqOutput,
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
-    InitWeightsUpdateGroupReqOutput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
-    ProfileReqOutput,
-    ProfileReqType,
     ReleaseMemoryOccupationReqInput,
-    ReleaseMemoryOccupationReqOutput,
     ResumeMemoryOccupationReqInput,
-    ResumeMemoryOccupationReqOutput,
     RpcReqInput,
     RpcReqOutput,
     SetInternalStateReq,
@@ -103,11 +94,8 @@ from sglang.srt.managers.io_struct import (
     UnloadLoRAAdapterReqInput,
     UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
-    UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
-    UpdateWeightsFromDistributedReqOutput,
     UpdateWeightsFromTensorReqInput,
-    UpdateWeightsFromTensorReqOutput,
 )
 from sglang.srt.managers.mm_utils import init_embedding_cache
 from sglang.srt.managers.schedule_batch import (
@@ -122,9 +110,18 @@ from sglang.srt.managers.schedule_policy import (
     PrefillAdder,
     SchedulePolicy,
 )
+from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
+from sglang.srt.managers.scheduler_metrics_mixin import (
+    RECORD_STEP_TIME,
+    SchedulerMetricsMixin,
+)
 from sglang.srt.managers.scheduler_output_processor_mixin import (
     SchedulerOutputProcessorMixin,
 )
+from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
+from sglang.srt.managers.scheduler_update_weights_mixin import (
+    SchedulerUpdateWeightsMixin,
+)
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
@@ -133,7 +130,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -166,7 +162,6 @@ logger = logging.getLogger(__name__)

 # Test retract decode for debugging purposes
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
-RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

 _is_cpu = is_cpu()
@@ -189,41 +184,11 @@ class EmbeddingBatchResult:
     bid: int


-class KvMetrics:
-    def __init__(self):
-        self.request_active_slots = None
-        self.request_total_slots = None
-        self.kv_active_blocks = None
-        self.kv_total_blocks = None
-        self.num_requests_waiting = None
-        self.gpu_cache_usage_perc = None
-        self.gpu_prefix_cache_hit_rate = None
-        self.data_parallel_rank = None
-
-
-class IdleSleeper:
-    """
-    In setups which have long inactivity periods it is desirable to reduce
-    system power consumption when sglang does nothing. This would lead not only
-    to power savings, but also to more CPU thermal headroom when a request
-    eventually comes. This is important in cases when multiple GPUs are connected
-    as each GPU would otherwise pin one thread at 100% CPU usage.
-
-    The simplest solution is to use zmq.Poller on all sockets that may receive
-    data that needs handling immediately.
-    """
-
-    def __init__(self, sockets):
-        self.poller = zmq.Poller()
-        for s in sockets:
-            self.poller.register(s, zmq.POLLIN)
-
-    def maybe_sleep(self):
-        self.poller.poll(1000)
-
-
 class Scheduler(
     SchedulerOutputProcessorMixin,
+    SchedulerUpdateWeightsMixin,
+    SchedulerProfilerMixin,
+    SchedulerMetricsMixin,
     SchedulerDisaggregationDecodeMixin,
     SchedulerDisaggregationPrefillMixin,
 ):
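
Note on the hunk above: this release splits scheduler.py into cooperating mixins (metrics, profiling, weight updates), which the Scheduler now composes through multiple inheritance while KvMetrics and IdleSleeper move elsewhere. The sketch below only illustrates that composition pattern; the mixin names echo the diff, but the class bodies, attributes, and the __main__ usage are invented for illustration and are not sglang code.

```python
# Illustrative sketch of the mixin-composition pattern used by the refactored Scheduler.
class MetricsMixin:
    def init_metrics(self):
        # Stand-in for SchedulerMetricsMixin's real state initialization.
        self.stats = {}

    def log_stats(self, key, value):
        self.stats[key] = value


class ProfilerMixin:
    def init_profiler(self):
        # Stand-in for SchedulerProfilerMixin's real state initialization.
        self.profile_in_progress = False


class Scheduler(MetricsMixin, ProfilerMixin):
    def __init__(self):
        # Each mixin contributes methods; the composing class still calls the
        # init hooks explicitly, as scheduler.py does with init_metrics()/init_profier().
        self.init_metrics()
        self.init_profiler()


if __name__ == "__main__":
    s = Scheduler()
    s.log_stats("num_running_reqs", 0)
    print(s.stats, s.profile_in_progress)
```
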
@@ -235,15 +200,18 @@ class Scheduler(
         port_args: PortArgs,
         gpu_id: int,
         tp_rank: int,
+        moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
     ):
         # Parse args
         self.server_args = server_args
         self.tp_rank = tp_rank
+        self.moe_ep_rank = moe_ep_rank
         self.pp_rank = pp_rank
         self.dp_rank = dp_rank
         self.tp_size = server_args.tp_size
+        self.moe_ep_size = server_args.ep_size
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
@@ -264,7 +232,7 @@ class Scheduler(
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
         self.enable_hicache_storage = server_args.hicache_storage_backend is not None
         self.page_size = server_args.page_size
-
+
         self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
@@ -282,10 +250,13 @@ class Scheduler(
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
             )
+            self.recv_from_rpc = get_zmq_socket(
+                context, zmq.DEALER, port_args.rpc_ipc_name, False
+            )
+
             self.send_to_tokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.tokenizer_ipc_name, False
             )
-
             if server_args.skip_tokenizer_init:
                 # Directly send to the TokenizerManager
                 self.send_to_detokenizer = get_zmq_socket(
@@ -297,9 +268,6 @@ class Scheduler(
                     context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                 )

-                self.recv_from_rpc = get_zmq_socket(
-                    context, zmq.DEALER, port_args.rpc_ipc_name, False
-                )
             if self.server_args.sleep_on_idle:
                 self.idle_sleeper = IdleSleeper(
                     [
@@ -345,6 +313,7 @@ class Scheduler(
                 server_args=server_args,
                 gpu_id=gpu_id,
                 tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
                 pp_rank=pp_rank,
                 dp_rank=dp_rank,
                 nccl_port=port_args.nccl_port,
@@ -357,6 +326,7 @@ class Scheduler(
             self.draft_worker = EAGLEWorker(
                 gpu_id=gpu_id,
                 tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
                 server_args=server_args,
                 nccl_port=port_args.nccl_port,
                 target_worker=self.tp_worker,
@@ -370,6 +340,7 @@ class Scheduler(
             self.max_total_num_tokens,
             self.max_prefill_tokens,
             self.max_running_requests,
+            self.max_queued_requests,
             self.max_req_len,
             self.max_req_input_len,
             self.random_seed,
@@ -395,7 +366,7 @@ class Scheduler(
         global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)

-        # Hybrid
+        # Hybrid memory pool
         self.is_hybrid = self.tp_worker.is_hybrid
         if self.is_hybrid:
             self.sliding_window_size = self.tp_worker.sliding_window_size
@@ -502,10 +473,25 @@ class Scheduler(
         )
         self.init_profier()

+        self.input_blocker = (
+            SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
+            if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
+            else None
+        )
+
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
         self.init_kv_events(server_args.kv_events_config)

+        # Init disaggregation
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.init_disaggregation()
+
+        if get_bool_env_var("SGLANG_GC_LOG"):
+            configure_gc_logger()
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -536,22 +522,6 @@ class Scheduler(
             ]
         )

-        # Init disaggregation
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.init_disaggregation()
-
-        if get_bool_env_var("SGLANG_GC_LOG"):
-            configure_gc_logger()
-
-    def current_scheduler_metrics_enabled(self):
-        return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
-
-    def maybe_sleep_on_idle(self):
-        if self.idle_sleeper is not None:
-            self.idle_sleeper.maybe_sleep()
-
     def init_tokenizer(self):
         server_args = self.server_args

@@ -659,50 +629,6 @@ class Scheduler(
         embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
         init_embedding_cache(embedding_cache_size * 1024 * 1024)

-    def init_profier(self):
-        self.torch_profiler = None
-        self.torch_profiler_output_dir: Optional[str] = None
-        self.profiler_activities: Optional[List[str]] = None
-        self.profile_id: Optional[str] = None
-        self.profiler_start_forward_ct: Optional[int] = None
-        self.profiler_target_forward_ct: Optional[int] = None
-        self.profiler_target_prefill_ct: Optional[int] = None
-        self.profiler_target_decode_ct: Optional[int] = None
-        self.profiler_prefill_ct: Optional[int] = None
-        self.profiler_decode_ct: Optional[int] = None
-        self.profile_by_stage: bool = False
-        self.profile_steps: Optional[int] = None
-        self.profile_in_progress: bool = False
-        self.rpd_profiler = None
-
-    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
-        self.last_gen_throughput: float = 0.0
-        self.last_input_throughput: float = 0.0
-        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
-        self.spec_num_total_accepted_tokens = 0
-        self.spec_num_total_forward_ct = 0
-        self.cum_spec_accept_length = 0
-        self.cum_spec_accept_count = 0
-        self.total_retracted_reqs = 0
-        self.stats = SchedulerStats()
-        if self.enable_metrics:
-            engine_type = "unified"
-            labels = {
-                "model_name": self.server_args.served_model_name,
-                "engine_type": engine_type,
-                "tp_rank": tp_rank,
-                "pp_rank": pp_rank,
-            }
-            if dp_rank is not None:
-                labels["dp_rank"] = dp_rank
-            self.metrics_collector = SchedulerMetricsCollector(labels=labels)
-
-    def init_kv_events(self, kv_events_config: Optional[str]):
-        if self.enable_kv_cache_events:
-            self.kv_event_publisher = EventPublisherFactory.create(
-                kv_events_config, self.attn_dp_rank
-            )
-
     def init_disaggregation(self):
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
@@ -811,10 +737,7 @@ class Scheduler(
                 self.process_batch_result(batch, result)
             else:
                 # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.check_tree_cache()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()

             self.last_batch = batch

@@ -857,10 +780,7 @@ class Scheduler(
                 )
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.check_tree_cache()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()

             self.last_batch = batch

@@ -994,10 +914,8 @@ class Scheduler(

             # When the server is idle, self-check and re-init some states
             if server_is_idle:
-                self.check_memory()
-                self.check_tree_cache()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                # When the server is idle, do self-check and re-init some states
+                self.self_check_during_idle()

     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
@@ -1033,6 +951,9 @@ class Scheduler(
         else:
             recv_reqs = None

+        if self.input_blocker is not None:
+            recv_reqs = self.input_blocker.handle(recv_reqs)
+
         if self.server_args.enable_dp_attention:
             if self.attn_tp_rank == 0:
                 work_reqs = [
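
recv_requests now routes incoming requests through the optional SchedulerInputBlocker (enabled via SGLANG_ENABLE_COLOCATED_BATCH_GEN) before dispatch. Its implementation lives in the new scheduler_input_blocker.py, which this hunk does not show; the sketch below is a rough, hypothetical illustration of a handle()-style gate modeled only on the call site above, so the block/unblock semantics and every internal name are assumptions.

```python
# Hypothetical sketch of an input gate with a handle() hook, loosely modeled on
# the SchedulerInputBlocker call site. Not sglang's actual implementation.
from typing import List, Optional


class InputGate:
    def __init__(self, noop: bool = False):
        self.noop = noop              # non-leader ranks pass requests straight through
        self.blocked = False
        self.pending: List[object] = []

    def handle(self, recv_reqs: Optional[List[object]]) -> Optional[List[object]]:
        if self.noop or recv_reqs is None:
            return recv_reqs
        if self.blocked:
            self.pending.extend(recv_reqs)   # hold requests while blocked
            return []
        released, self.pending = self.pending, []
        return released + recv_reqs          # flush anything held earlier


gate = InputGate()
gate.blocked = True
assert gate.handle(["req-1"]) == []
gate.blocked = False
assert gate.handle(["req-2"]) == ["req-1", "req-2"]
```
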
@@ -1086,6 +1007,19 @@ class Scheduler(
                 self.return_health_check_ct += 1
                 continue

+            # If it is a work request, accept or reject the request based on the request queue size.
+            if is_work_request(recv_req):
+                if len(self.waiting_queue) + 1 > self.max_queued_requests:
+                    abort_req = AbortReq(
+                        recv_req.rid,
+                        finished_reason={
+                            "type": "abort",
+                            "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+                            "message": "The request queue is full.",
+                        },
+                    )
+                    self.send_to_tokenizer.send_pyobj(abort_req)
+                    continue
             output = self._request_dispatcher(recv_req)
             if output is not None:
                 if isinstance(output, RpcReqOutput):
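
The new admission check rejects a work request as soon as the waiting queue would exceed max_queued_requests, sending back an abort with HTTP 503 instead of queueing it. The standalone sketch below mirrors that bounded-queue idea; the admit() helper and the plain dict/list types are illustrative stand-ins, not sglang's io_struct classes.

```python
# Minimal sketch of bounded-queue admission control, mirroring the check above.
from http import HTTPStatus
from typing import List, Optional


def admit(waiting_queue: List[str], rid: str, max_queued_requests: int) -> Optional[dict]:
    if len(waiting_queue) + 1 > max_queued_requests:
        # Reject instead of queueing; the scheduler sends this back to the tokenizer.
        return {
            "rid": rid,
            "type": "abort",
            "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
            "message": "The request queue is full.",
        }
    waiting_queue.append(rid)
    return None


queue: List[str] = []
assert admit(queue, "a", max_queued_requests=1) is None
assert admit(queue, "b", max_queued_requests=1)["status_code"] == HTTPStatus.SERVICE_UNAVAILABLE
```
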
@@ -1256,23 +1190,28 @@ class Scheduler(
     def _add_request_to_queue(self, req: Req):
         req.queue_time_start = time.perf_counter()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self._prefetch_kvcache(req)
             self.disagg_prefill_bootstrap_queue.add(
                 req, self.model_config.num_key_value_heads
             )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.disagg_decode_prealloc_queue.add(req)
         else:
-            if self.enable_hicache_storage:
-                req.init_next_round_input(self.tree_cache)
-                last_hash = req.last_host_node.get_last_hash_value()
-                matched_len = len(req.prefix_indices) + req.host_hit_length
-                if (matched_len > 0 and last_hash is not None) or matched_len == 0:
-                    new_input_tokens = req.fill_ids[matched_len:]
-                    self.tree_cache.prefetch_from_storage(
-                        req.rid, req.last_host_node, new_input_tokens, last_hash
-                    )
+            self._prefetch_kvcache(req)
             self.waiting_queue.append(req)

+    def _prefetch_kvcache(self, req: Req):
+        if self.enable_hicache_storage:
+            req.init_next_round_input(self.tree_cache)
+            last_hash = req.last_host_node.get_last_hash_value()
+            matched_len = len(req.prefix_indices) + req.host_hit_length
+            # todo, free-form fetching, calculating hash keys on the fly
+            if (matched_len > 0 and last_hash is not None) or matched_len == 0:
+                new_input_tokens = req.fill_ids[matched_len:]
+                self.tree_cache.prefetch_from_storage(
+                    req.rid, req.last_host_node, new_input_tokens, last_hash
+                )
+
     def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.disagg_prefill_bootstrap_queue.extend(
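
The factored-out _prefetch_kvcache helper (now also called on the prefill/disaggregation path) computes how much of the prompt is already covered by the device prefix plus the host cache hit, and prefetches only the remaining tokens from storage. A toy illustration of that slicing arithmetic, with all token values and lengths invented for the example:

```python
# Toy illustration of the matched-length arithmetic used by _prefetch_kvcache.
fill_ids = [101, 102, 103, 104, 105, 106]  # full prompt token ids (made up)
prefix_indices = [0, 1]                    # tokens already cached on device
host_hit_length = 2                        # additional tokens hit in the host cache

matched_len = len(prefix_indices) + host_hit_length
new_input_tokens = fill_ids[matched_len:]  # only these need prefetching from storage

assert matched_len == 4
assert new_input_tokens == [105, 106]
```
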
@@ -1330,170 +1269,11 @@ class Scheduler(
         req.logprob_start_len = len(req.origin_input_ids) - 1
         self._add_request_to_queue(req)

-    def _emit_kv_metrics(self):
-        kv_metrics = KvMetrics()
-        kv_metrics.request_active_slots = self.stats.num_running_reqs
-        kv_metrics.request_total_slots = self.max_running_requests
-        kv_metrics.kv_active_blocks = int(
-            self.stats.token_usage * self.max_total_num_tokens
-        )
-        kv_metrics.kv_total_blocks = self.max_total_num_tokens
-        kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
-        kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
-        kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
-        kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
-
-        if not self.send_metrics_from_scheduler.closed:
-            self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
-
-    def log_prefill_stats(
-        self,
-        adder: PrefillAdder,
-        can_run_list: List[Req],
-        running_bs: int,
-    ):
-        gap_latency = time.perf_counter() - self.last_prefill_stats_tic
-        self.last_prefill_stats_tic = time.perf_counter()
-        self.last_input_throughput = self.last_prefill_tokens / gap_latency
-        self.last_prefill_tokens = adder.log_input_tokens
-
-        if self.is_hybrid:
-            (
-                full_num_used,
-                swa_num_used,
-                full_token_usage,
-                swa_token_usage,
-                _,
-                _,
-                _,
-                _,
-            ) = self._get_swa_token_info()
-            num_used = max(full_num_used, swa_num_used)
-            token_usage = max(full_token_usage, swa_token_usage)
-            token_msg = (
-                f"full token usage: {full_token_usage:.2f}, "
-                f"swa token usage: {swa_token_usage:.2f}, "
-            )
-        else:
-            num_used, token_usage, _, _ = self._get_token_info()
-            token_msg = f"token usage: {token_usage:.2f}, "
-
-        num_new_seq = len(can_run_list)
-        f = (
-            f"Prefill batch. "
-            f"#new-seq: {num_new_seq}, "
-            f"#new-token: {adder.log_input_tokens}, "
-            f"#cached-token: {adder.log_hit_tokens}, "
-            f"{token_msg}"
-        )
-
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
-            f += f"#queue-req: {len(self.waiting_queue)}, "
-            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
-            f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
-        else:
-            f += f"#running-req: {running_bs}, "
-            f += f"#queue-req: {len(self.waiting_queue)}, "
-
-        logger.info(f)
-
-        if self.enable_metrics:
-            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
-
-            cache_hit_rate = (
-                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
-            )
-            self.stats.num_running_reqs = running_bs
-            self.stats.num_used_tokens = num_used
-            self.stats.token_usage = round(token_usage, 2)
-            self.stats.num_queue_reqs = len(self.waiting_queue)
-            self.stats.cache_hit_rate = cache_hit_rate
-
-            total_queue_latency = 0
-            for req in can_run_list:
-                total_queue_latency += req.queue_time_end - req.queue_time_start
-            self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
-
-            self.metrics_collector.log_stats(self.stats)
-            self._emit_kv_metrics()
-            self._publish_kv_events()
-
-    def log_decode_stats(
-        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
-    ):
-        batch = running_batch or self.running_batch
-
-        gap_latency = time.perf_counter() - self.last_decode_stats_tic
-        self.last_decode_stats_tic = time.perf_counter()
-        self.last_gen_throughput = self.num_generated_tokens / gap_latency
-        self.num_generated_tokens = 0
-        num_running_reqs = len(batch.reqs)
-        if self.is_hybrid:
-            (
-                full_num_used,
-                swa_num_used,
-                full_token_usage,
-                swa_token_usage,
-                _,
-                _,
-                _,
-                _,
-            ) = self._get_swa_token_info()
-            num_used = max(full_num_used, swa_num_used)
-            token_usage = max(full_token_usage, swa_token_usage)
-            token_msg = (
-                f"#full token: {full_num_used}, "
-                f"full token usage: {full_token_usage:.2f}, "
-                f"#swa token: {swa_num_used}, "
-                f"swa token usage: {swa_token_usage:.2f}, "
-            )
-        else:
-            num_used, token_usage, _, _ = self._get_token_info()
-            token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
-
-        if RECORD_STEP_TIME:
-            self.step_time_dict[num_running_reqs].append(
-                gap_latency / self.server_args.decode_log_interval
-            )
-
-        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
-
-        if self.spec_algorithm.is_none():
-            spec_accept_length = 0
-        else:
-            spec_accept_length = (
-                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
-            )
-            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
-            self.cum_spec_accept_count += self.spec_num_total_forward_ct
-            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
-            msg += f"accept len: {spec_accept_length:.2f}, "
-
-        if self.disaggregation_mode == DisaggregationMode.DECODE:
-            msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
-            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
-
-        msg += (
-            f"cuda graph: {can_run_cuda_graph}, "
-            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"#queue-req: {len(self.waiting_queue)}, "
-        )
-
-        logger.info(msg)
-        if self.enable_metrics:
-            self.stats.num_running_reqs = num_running_reqs
-            self.stats.num_used_tokens = num_used
-            self.stats.token_usage = round(token_usage, 2)
-            self.stats.cache_hit_rate = 0.0
-            self.stats.gen_throughput = self.last_gen_throughput
-            self.stats.num_queue_reqs = len(self.waiting_queue)
-            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
-            self.stats.spec_accept_length = spec_accept_length
-            self.stats.total_retracted_reqs = self.total_retracted_reqs
-            self.metrics_collector.log_stats(self.stats)
-            self._emit_kv_metrics()
-            self._publish_kv_events()
+    def self_check_during_idle(self):
+        self.check_memory()
+        self.check_tree_cache()
+        self.new_token_ratio = self.init_new_token_ratio
+        self.maybe_sleep_on_idle()

     def check_memory(self):
         if self.is_hybrid:
@@ -2397,22 +2177,6 @@ class Scheduler(
             barrier()
         return RpcReqOutput(success, "" if not exec else str(exec))

-    def save_remote_model(self, params):
-        url = params["url"]
-
-        worker = self.tp_worker.worker
-
-        worker.model_runner.save_remote_model(url)
-
-    def save_sharded_model(self, params):
-        worker = self.tp_worker.worker
-
-        worker.model_runner.save_sharded_model(
-            path=params["path"],
-            pattern=params["pattern"],
-            max_size=params["max_size"],
-        )
-
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
         to_del = []
@@ -2490,16 +2254,6 @@ class Scheduler(
     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()

-    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
-        """In-place update of the weights from disk."""
-        success, message = self.tp_worker.update_weights_from_disk(recv_req)
-        if success:
-            flush_cache_success = self.flush_cache()
-            assert flush_cache_success, "Cache flush failed after updating weights"
-        else:
-            logger.error(message)
-        return UpdateWeightFromDiskReqOutput(success, message, 0)
-
     def load_lora_adapter(
         self, recv_req: LoadLoRAAdapterReqInput
     ) -> LoadLoRAAdapterReqOutput:
@@ -2516,81 +2270,6 @@ class Scheduler(
         result = self.tp_worker.unload_lora_adapter(recv_req)
         return result

-    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
-        """Initialize the online model parameter update group."""
-        success, message = self.tp_worker.init_weights_update_group(recv_req)
-        return InitWeightsUpdateGroupReqOutput(success, message)
-
-    def update_weights_from_distributed(
-        self,
-        recv_req: UpdateWeightsFromDistributedReqInput,
-    ) -> Tuple[bool, str]:
-        """Update the online model parameter."""
-        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
-        if success:
-            if recv_req.flush_cache:
-                flush_cache_success = self.flush_cache()
-                assert flush_cache_success, "Cache flush failed after updating weights"
-        else:
-            logger.error(message)
-        return UpdateWeightsFromDistributedReqOutput(success, message)
-
-    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
-        """Update the online model parameter from tensors."""
-        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
-        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
-        if success:
-            if recv_req.flush_cache:
-                flush_cache_success = self.flush_cache()
-                assert flush_cache_success, "Cache flush failed after updating weights"
-        else:
-            logger.error(message)
-        barrier(group=self.tp_cpu_group)
-        return UpdateWeightsFromTensorReqOutput(success, message)
-
-    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
-        parameter = self.tp_worker.get_weights_by_name(recv_req)
-        return GetWeightsByNameReqOutput(parameter)
-
-    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
-        tags = recv_req.tags
-
-        if tags is None or len(tags) == 0:
-            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
-
-        if GPU_MEMORY_TYPE_KV_CACHE in tags:
-            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
-            self.flush_cache()
-
-        if GPU_MEMORY_TYPE_WEIGHTS in tags:
-            self.stashed_model_static_state = _export_static_state(
-                self.tp_worker.worker.model_runner.model
-            )
-            torch.distributed.barrier(self.tp_cpu_group)
-            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
-
-        return ReleaseMemoryOccupationReqOutput()
-
-    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
-        tags = recv_req.tags
-
-        if tags is None or len(tags) == 0:
-            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
-
-        if GPU_MEMORY_TYPE_WEIGHTS in tags:
-            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
-            torch.distributed.barrier(self.tp_cpu_group)
-            _import_static_state(
-                self.tp_worker.worker.model_runner.model,
-                self.stashed_model_static_state,
-            )
-            del self.stashed_model_static_state
-
-        if GPU_MEMORY_TYPE_KV_CACHE in tags:
-            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
-
-        return ResumeMemoryOccupationReqOutput()
-
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:
@@ -2598,254 +2277,6 @@ class Scheduler(
         self.forward_sleep_time = t
         return SlowDownReqOutput()

-    def profile(self, recv_req: ProfileReq):
-        if recv_req.type == ProfileReqType.START_PROFILE:
-            if recv_req.profile_by_stage or recv_req.start_step:
-                return self.init_profile(
-                    recv_req.output_dir,
-                    recv_req.start_step,
-                    recv_req.num_steps,
-                    recv_req.activities,
-                    recv_req.with_stack,
-                    recv_req.record_shapes,
-                    recv_req.profile_by_stage,
-                    recv_req.profile_id,
-                )
-            else:
-                self.init_profile(
-                    recv_req.output_dir,
-                    recv_req.start_step,
-                    recv_req.num_steps,
-                    recv_req.activities,
-                    recv_req.with_stack,
-                    recv_req.record_shapes,
-                    recv_req.profile_by_stage,
-                    recv_req.profile_id,
-                )
-                return self.start_profile(True)
-        else:
-            return self.stop_profile()
-
-    def init_profile(
-        self,
-        output_dir: Optional[str],
-        start_step: Optional[int],
-        num_steps: Optional[int],
-        activities: Optional[List[str]],
-        with_stack: Optional[bool],
-        record_shapes: Optional[bool],
-        profile_by_stage: bool,
-        profile_id: str,
-    ) -> ProfileReqOutput:
-        if self.profile_in_progress:
-            return ProfileReqOutput(
-                success=False,
-                message="Profiling is already in progress. Call /stop_profile first.",
-            )
-
-        self.profile_by_stage = profile_by_stage
-
-        if output_dir is None:
-            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
-        if activities is None:
-            activities = ["CPU", "GPU"]
-
-        self.torch_profiler_output_dir = output_dir
-        self.torch_profiler_with_stack = with_stack
-        self.torch_profiler_record_shapes = record_shapes
-        self.profiler_activities = activities
-        self.profile_id = profile_id
-
-        if start_step:
-            self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
-
-        if num_steps:
-            self.profile_steps = num_steps
-            if self.profile_by_stage:
-                self.profiler_target_prefill_ct = num_steps
-                self.profiler_target_decode_ct = num_steps
-                self.profiler_prefill_ct = 0
-                self.profiler_decode_ct = 0
-            elif start_step:
-                self.profiler_target_forward_ct = (
-                    self.profiler_start_forward_ct + num_steps
-                )
-            else:
-                self.profiler_target_forward_ct = self.forward_ct + num_steps
-                # The caller will be notified when reaching profiler_target_forward_ct
-        else:
-            self.profiler_target_forward_ct = None
-
-        return ProfileReqOutput(success=True, message="Succeeded")
-
-    def start_profile(
-        self, stage: Optional[ForwardMode] = None
-    ) -> ProfileReqOutput | None:
-        stage_str = f" for {stage.__str__()}" if stage else ""
-        logger.info(
-            f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
-        )
-
-        activities = self.profiler_activities
-        with_stack = self.torch_profiler_with_stack
-        record_shapes = self.torch_profiler_record_shapes
-
-        activity_map = {
-            "CPU": torch.profiler.ProfilerActivity.CPU,
-            "GPU": torch.profiler.ProfilerActivity.CUDA,
-        }
-        torchprof_activities = [
-            activity_map[a] for a in activities if a in activity_map
-        ]
-
-        if "RPD" in activities:
-            from rpdTracerControl import rpdTracerControl
-
-            rpdTracerControl.skipCreate()
-
-            self.rpd_profile_path = os.path.join(
-                self.torch_profiler_output_dir,
-                "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
-            )
-
-            if self.tp_rank == 0:
-                import sqlite3
-
-                from rocpd.schema import RocpdSchema
-
-                if os.path.exists("trace.rpd"):
-                    os.unlink("trace.rpd")
-                schema = RocpdSchema()
-                connection = sqlite3.connect("trace.rpd")
-                schema.writeSchema(connection)
-                connection.commit()
-                del connection
-            torch.distributed.barrier(self.tp_cpu_group)
-
-            self.rpd_profiler = rpdTracerControl()
-            self.rpd_profiler.setPythonTrace(True)
-            self.rpd_profiler.start()
-            self.rpd_profiler.rangePush("", "rpd profile range", "")
-            self.profile_in_progress = True
-        elif torchprof_activities:
-            self.torch_profiler = torch.profiler.profile(
-                activities=torchprof_activities,
-                with_stack=with_stack if with_stack is not None else True,
-                record_shapes=record_shapes if record_shapes is not None else False,
-            )
-            self.torch_profiler.start()
-            self.profile_in_progress = True
-
-        if "MEM" in activities:
-            torch.cuda.memory._record_memory_history(max_entries=100000)
-            self.profile_in_progress = True
-
-        if "CUDA_PROFILER" in activities:
-            torch.cuda.cudart().cudaProfilerStart()
-            self.profile_in_progress = True
-
-        return ProfileReqOutput(success=True, message="Succeeded")
-
-    def stop_profile(
-        self, stage: Optional[ForwardMode] = None
-    ) -> ProfileReqOutput | None:
-        if not self.profile_in_progress:
-            return ProfileReqOutput(
-                success=False,
-                message="Profiling is not in progress. Call /start_profile first.",
-            )
-
-        if not Path(self.torch_profiler_output_dir).exists():
-            Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
-
-        stage_suffix = f"-{stage.__str__()}" if stage else ""
-        logger.info("Stop profiling" + stage_suffix + "...")
-        if self.torch_profiler is not None:
-            self.torch_profiler.stop()
-            self.torch_profiler.export_chrome_trace(
-                os.path.join(
-                    self.torch_profiler_output_dir,
-                    self.profile_id
-                    + f"-TP-{self.tp_rank}"
-                    + stage_suffix
-                    + ".trace.json.gz",
-                )
-            )
-            torch.distributed.barrier(self.tp_cpu_group)
-
-        if self.rpd_profiler is not None:
-            self.rpd_profiler.rangePop()
-            self.rpd_profiler.stop()
-            self.rpd_profiler.flush()
-
-            torch.distributed.barrier(self.tp_cpu_group)
-            if self.tp_rank == 0:
-                from sglang.srt.utils import rpd_to_chrome_trace
-
-                rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
-            self.rpd_profiler = None
-            self.rpd_profiler_path = None
-
-        if self.profiler_activities is not None and "MEM" in self.profiler_activities:
-            memory_profile_path = os.path.join(
-                self.torch_profiler_output_dir,
-                str(time.time())
-                + f"-TP-{self.tp_rank}-memory"
-                + stage_suffix
-                + ".pickle",
-            )
-            torch.cuda.memory._dump_snapshot(memory_profile_path)
-            torch.cuda.memory._record_memory_history(enabled=None)
-
-        if "CUDA_PROFILER" in self.profiler_activities:
-            torch.cuda.cudart().cudaProfilerStop()
-
-        logger.info(
-            "Profiling done. Traces are saved to: %s",
-            self.torch_profiler_output_dir,
-        )
-        self.torch_profiler = None
-        self.profile_in_progress = False
-        self.profiler_start_forward_ct = None
-
-        return ProfileReqOutput(success=True, message="Succeeded.")
-
-    def _profile_batch_predicate(self, batch):
-        if self.profile_by_stage:
-            if batch.forward_mode.is_prefill():
-                if self.profiler_prefill_ct == 0:
-                    self.start_profile(batch.forward_mode)
-                self.profiler_prefill_ct += 1
-                if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
-                    if self.profile_in_progress:
-                        self.stop_profile(stage=ForwardMode.EXTEND)
-            elif batch.forward_mode.is_decode():
-                if self.profiler_decode_ct == 0:
-                    if self.profile_in_progress:
-                        # force trace flush
-                        self.stop_profile(ForwardMode.EXTEND)
-                    self.start_profile(batch.forward_mode)
-                self.profiler_decode_ct += 1
-                if self.profiler_decode_ct > self.profiler_target_decode_ct:
-                    if self.profile_in_progress:
-                        self.stop_profile(stage=ForwardMode.DECODE)
-            elif batch.forward_mode.is_idle():
-                pass
-            else:
-                raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
-        else:
-            # Check profiler
-            if (
-                self.profiler_target_forward_ct
-                and self.profiler_target_forward_ct <= self.forward_ct
-            ):
-                self.stop_profile()
-            if (
-                self.profiler_start_forward_ct
-                and self.profiler_start_forward_ct == self.forward_ct
-            ):
-                self.start_profile()
-
     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
         if recv_req == ExpertDistributionReq.START_RECORD:
             get_global_expert_distribution_recorder().start_record()
@@ -2854,7 +2285,7 @@ class Scheduler(
         elif recv_req == ExpertDistributionReq.DUMP_RECORD:
             get_global_expert_distribution_recorder().dump_record()
         else:
-            raise ValueError("Unrecognized ExpertDistributionReq value")
+            raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
         return ExpertDistributionReqOutput()

     def open_session(self, recv_req: OpenSessionReqInput):
@@ -2890,30 +2321,41 @@ class Scheduler(
             prefix += f" PP{self.pp_rank}"
         return prefix

-    def _publish_kv_events(self):
-        if self.enable_kv_cache_events:
-            events = self.tree_cache.take_events()
-            if events:
-                batch = KVEventBatch(ts=time.time(), events=events)
-                self.kv_event_publisher.publish(batch)
+    def current_scheduler_metrics_enabled(self):
+        return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers

+    def maybe_sleep_on_idle(self):
+        if self.idle_sleeper is not None:
+            self.idle_sleeper.maybe_sleep()

-def is_health_check_generate_req(recv_req):
-    return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")

+class IdleSleeper:
+    """
+    In setups which have long inactivity periods it is desirable to reduce
+    system power consumption when sglang does nothing. This would lead not only
+    to power savings, but also to more CPU thermal headroom when a request
+    eventually comes. This is important in cases when multiple GPUs are connected
+    as each GPU would otherwise pin one thread at 100% CPU usage.

-def _export_static_state(model):
-    return dict(
-        buffers=[
-            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
-        ]
-    )
+    The simplest solution is to use zmq.Poller on all sockets that may receive
+    data that needs handling immediately.
+    """

+    def __init__(self, sockets):
+        self.poller = zmq.Poller()
+        for s in sockets:
+            self.poller.register(s, zmq.POLLIN)

-def _import_static_state(model, static_params):
-    self_named_buffers = dict(model.named_buffers())
-    for name, tensor in static_params["buffers"]:
-        self_named_buffers[name].copy_(tensor)
+    def maybe_sleep(self):
+        self.poller.poll(1000)
+
+
+def is_health_check_generate_req(recv_req):
+    return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
+
+
+def is_work_request(recv_req):
+    return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))


 def run_scheduler_process(
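
IdleSleeper keeps the same behavior but now lives at module level: it registers every receiving socket with a zmq.Poller and blocks for up to one second when the scheduler has nothing to do, instead of spinning. A self-contained sketch of that polling pattern follows; the inproc endpoint and the demo message are illustrative, not sglang's actual IPC names.

```python
# Standalone sketch of the zmq.Poller idle-sleep pattern used by IdleSleeper.
import zmq

ctx = zmq.Context()
receiver = ctx.socket(zmq.PULL)
receiver.bind("inproc://demo")          # illustrative endpoint

sender = ctx.socket(zmq.PUSH)
sender.connect("inproc://demo")

poller = zmq.Poller()
poller.register(receiver, zmq.POLLIN)

sender.send_pyobj({"rid": "demo-req"})

# poll() returns as soon as a socket is readable, or after the 1000 ms timeout,
# so an idle loop sleeps in the kernel instead of pinning a CPU core at 100%.
events = dict(poller.poll(1000))
if receiver in events:
    print(receiver.recv_pyobj())

ctx.destroy()
```
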
@@ -2921,6 +2363,7 @@ def run_scheduler_process(
     port_args: PortArgs,
     gpu_id: int,
     tp_rank: int,
+    moe_ep_rank: int,
     pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
@@ -2931,6 +2374,8 @@ def run_scheduler_process(
         prefix += f" DP{dp_rank}"
     if server_args.tp_size > 1:
         prefix += f" TP{tp_rank}"
+    if server_args.ep_size > 1:
+        prefix += f" EP{moe_ep_rank}"
     if server_args.pp_size > 1:
         prefix += f" PP{pp_rank}"

@@ -2954,7 +2399,9 @@ def run_scheduler_process(

     # Create a scheduler and run the event loop
     try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
+        scheduler = Scheduler(
+            server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank
+        )
         pipe_writer.send(
             {
                 "status": "ready",