sglang 0.4.9.post6__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +8 -0
  3. sglang/srt/configs/model_config.py +3 -0
  4. sglang/srt/configs/step3_vl.py +172 -0
  5. sglang/srt/conversation.py +23 -0
  6. sglang/srt/disaggregation/decode.py +2 -8
  7. sglang/srt/disaggregation/prefill.py +2 -6
  8. sglang/srt/distributed/parallel_state.py +86 -1
  9. sglang/srt/entrypoints/engine.py +14 -18
  10. sglang/srt/entrypoints/http_server.py +10 -2
  11. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  12. sglang/srt/eplb/expert_distribution.py +5 -0
  13. sglang/srt/eplb/expert_location.py +17 -6
  14. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  15. sglang/srt/eplb/expert_location_updater.py +2 -0
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/step3_detector.py +436 -0
  18. sglang/srt/hf_transformers_utils.py +2 -0
  19. sglang/srt/jinja_template_utils.py +4 -1
  20. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  21. sglang/srt/layers/moe/ep_moe/layer.py +20 -640
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
  24. sglang/srt/layers/quantization/fp8.py +0 -18
  25. sglang/srt/layers/quantization/unquant.py +0 -8
  26. sglang/srt/layers/quantization/w4afp8.py +1 -0
  27. sglang/srt/managers/cache_controller.py +143 -45
  28. sglang/srt/managers/data_parallel_controller.py +2 -0
  29. sglang/srt/managers/io_struct.py +0 -2
  30. sglang/srt/managers/scheduler.py +89 -671
  31. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  32. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  33. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  34. sglang/srt/managers/template_manager.py +62 -19
  35. sglang/srt/managers/tokenizer_manager.py +123 -74
  36. sglang/srt/managers/tp_worker.py +4 -0
  37. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  38. sglang/srt/mem_cache/hicache_storage.py +45 -11
  39. sglang/srt/mem_cache/hiradix_cache.py +15 -4
  40. sglang/srt/mem_cache/memory_pool_host.py +73 -1
  41. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  42. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  43. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
  44. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  45. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  46. sglang/srt/model_executor/model_runner.py +5 -0
  47. sglang/srt/models/arcee.py +532 -0
  48. sglang/srt/models/deepseek_v2.py +2 -0
  49. sglang/srt/models/glm4_moe.py +3 -1
  50. sglang/srt/models/granitemoe.py +3 -0
  51. sglang/srt/models/grok.py +3 -0
  52. sglang/srt/models/hunyuan.py +1 -0
  53. sglang/srt/models/llama4.py +3 -0
  54. sglang/srt/models/mixtral.py +3 -0
  55. sglang/srt/models/olmoe.py +3 -0
  56. sglang/srt/models/phimoe.py +1 -0
  57. sglang/srt/models/step3_vl.py +994 -0
  58. sglang/srt/multimodal/processors/base_processor.py +15 -16
  59. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  60. sglang/srt/reasoning_parser.py +2 -1
  61. sglang/srt/server_args.py +10 -13
  62. sglang/srt/speculative/eagle_worker.py +2 -0
  63. sglang/utils.py +0 -11
  64. sglang/version.py +1 -1
  65. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/METADATA +3 -4
  66. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/RECORD +69 -56
  67. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
  68. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
  69. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,6 @@
13
13
  # ==============================================================================
14
14
  """A scheduler that manages a tensor parallel GPU worker."""
15
15
 
16
- import datetime
17
16
  import faulthandler
18
17
  import logging
19
18
  import os
@@ -21,11 +20,10 @@ import signal
21
20
  import sys
22
21
  import threading
23
22
  import time
24
- from collections import defaultdict, deque
23
+ from collections import deque
25
24
  from concurrent import futures
26
25
  from dataclasses import dataclass
27
26
  from http import HTTPStatus
28
- from pathlib import Path
29
27
  from types import SimpleNamespace
30
28
  from typing import Dict, List, Optional, Tuple, Union
31
29
 
@@ -37,7 +35,6 @@ from torch.distributed import barrier
37
35
 
38
36
  from sglang.global_config import global_config
39
37
  from sglang.srt.configs.model_config import ModelConfig
40
- from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
41
38
  from sglang.srt.constrained.base_grammar_backend import (
42
39
  INVALID_GRAMMAR_OBJ,
43
40
  create_grammar_backend,
@@ -47,7 +44,6 @@ from sglang.srt.disaggregation.decode import (
47
44
  DecodeTransferQueue,
48
45
  SchedulerDisaggregationDecodeMixin,
49
46
  )
50
- from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
51
47
  from sglang.srt.disaggregation.prefill import (
52
48
  PrefillBootstrapQueue,
53
49
  SchedulerDisaggregationPrefillMixin,
@@ -78,21 +74,15 @@ from sglang.srt.managers.io_struct import (
78
74
  GetInternalStateReq,
79
75
  GetInternalStateReqOutput,
80
76
  GetWeightsByNameReqInput,
81
- GetWeightsByNameReqOutput,
82
77
  HealthCheckOutput,
83
78
  InitWeightsUpdateGroupReqInput,
84
- InitWeightsUpdateGroupReqOutput,
85
79
  LoadLoRAAdapterReqInput,
86
80
  LoadLoRAAdapterReqOutput,
87
81
  OpenSessionReqInput,
88
82
  OpenSessionReqOutput,
89
83
  ProfileReq,
90
- ProfileReqOutput,
91
- ProfileReqType,
92
84
  ReleaseMemoryOccupationReqInput,
93
- ReleaseMemoryOccupationReqOutput,
94
85
  ResumeMemoryOccupationReqInput,
95
- ResumeMemoryOccupationReqOutput,
96
86
  RpcReqInput,
97
87
  RpcReqOutput,
98
88
  SetInternalStateReq,
@@ -104,11 +94,8 @@ from sglang.srt.managers.io_struct import (
104
94
  UnloadLoRAAdapterReqInput,
105
95
  UnloadLoRAAdapterReqOutput,
106
96
  UpdateWeightFromDiskReqInput,
107
- UpdateWeightFromDiskReqOutput,
108
97
  UpdateWeightsFromDistributedReqInput,
109
- UpdateWeightsFromDistributedReqOutput,
110
98
  UpdateWeightsFromTensorReqInput,
111
- UpdateWeightsFromTensorReqOutput,
112
99
  )
113
100
  from sglang.srt.managers.mm_utils import init_embedding_cache
114
101
  from sglang.srt.managers.schedule_batch import (
@@ -124,9 +111,17 @@ from sglang.srt.managers.schedule_policy import (
124
111
  SchedulePolicy,
125
112
  )
126
113
  from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
114
+ from sglang.srt.managers.scheduler_metrics_mixin import (
115
+ RECORD_STEP_TIME,
116
+ SchedulerMetricsMixin,
117
+ )
127
118
  from sglang.srt.managers.scheduler_output_processor_mixin import (
128
119
  SchedulerOutputProcessorMixin,
129
120
  )
121
+ from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
122
+ from sglang.srt.managers.scheduler_update_weights_mixin import (
123
+ SchedulerUpdateWeightsMixin,
124
+ )
130
125
  from sglang.srt.managers.session_controller import Session
131
126
  from sglang.srt.managers.tp_worker import TpModelWorker
132
127
  from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
@@ -135,7 +130,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
135
130
  from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
136
131
  from sglang.srt.mem_cache.radix_cache import RadixCache
137
132
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
138
- from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
139
133
  from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
140
134
  from sglang.srt.reasoning_parser import ReasoningParser
141
135
  from sglang.srt.server_args import PortArgs, ServerArgs
@@ -168,7 +162,6 @@ logger = logging.getLogger(__name__)
168
162
 
169
163
  # Test retract decode for debugging purposes
170
164
  TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
171
- RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
172
165
  GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
173
166
 
174
167
  _is_cpu = is_cpu()
@@ -191,41 +184,11 @@ class EmbeddingBatchResult:
191
184
  bid: int
192
185
 
193
186
 
194
- class KvMetrics:
195
- def __init__(self):
196
- self.request_active_slots = None
197
- self.request_total_slots = None
198
- self.kv_active_blocks = None
199
- self.kv_total_blocks = None
200
- self.num_requests_waiting = None
201
- self.gpu_cache_usage_perc = None
202
- self.gpu_prefix_cache_hit_rate = None
203
- self.data_parallel_rank = None
204
-
205
-
206
- class IdleSleeper:
207
- """
208
- In setups which have long inactivity periods it is desirable to reduce
209
- system power consumption when sglang does nothing. This would lead not only
210
- to power savings, but also to more CPU thermal headroom when a request
211
- eventually comes. This is important in cases when multiple GPUs are connected
212
- as each GPU would otherwise pin one thread at 100% CPU usage.
213
-
214
- The simplest solution is to use zmq.Poller on all sockets that may receive
215
- data that needs handling immediately.
216
- """
217
-
218
- def __init__(self, sockets):
219
- self.poller = zmq.Poller()
220
- for s in sockets:
221
- self.poller.register(s, zmq.POLLIN)
222
-
223
- def maybe_sleep(self):
224
- self.poller.poll(1000)
225
-
226
-
227
187
  class Scheduler(
228
188
  SchedulerOutputProcessorMixin,
189
+ SchedulerUpdateWeightsMixin,
190
+ SchedulerProfilerMixin,
191
+ SchedulerMetricsMixin,
229
192
  SchedulerDisaggregationDecodeMixin,
230
193
  SchedulerDisaggregationPrefillMixin,
231
194
  ):
@@ -237,15 +200,18 @@ class Scheduler(
237
200
  port_args: PortArgs,
238
201
  gpu_id: int,
239
202
  tp_rank: int,
203
+ moe_ep_rank: int,
240
204
  pp_rank: int,
241
205
  dp_rank: Optional[int],
242
206
  ):
243
207
  # Parse args
244
208
  self.server_args = server_args
245
209
  self.tp_rank = tp_rank
210
+ self.moe_ep_rank = moe_ep_rank
246
211
  self.pp_rank = pp_rank
247
212
  self.dp_rank = dp_rank
248
213
  self.tp_size = server_args.tp_size
214
+ self.moe_ep_size = server_args.ep_size
249
215
  self.pp_size = server_args.pp_size
250
216
  self.dp_size = server_args.dp_size
251
217
  self.schedule_policy = server_args.schedule_policy
@@ -266,7 +232,7 @@ class Scheduler(
266
232
  self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
267
233
  self.enable_hicache_storage = server_args.hicache_storage_backend is not None
268
234
  self.page_size = server_args.page_size
269
- self.dp_size = server_args.dp_size
235
+
270
236
  self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
271
237
  compute_dp_attention_world_info(
272
238
  server_args.enable_dp_attention,
@@ -284,10 +250,13 @@ class Scheduler(
284
250
  self.recv_from_tokenizer = get_zmq_socket(
285
251
  context, zmq.PULL, port_args.scheduler_input_ipc_name, False
286
252
  )
253
+ self.recv_from_rpc = get_zmq_socket(
254
+ context, zmq.DEALER, port_args.rpc_ipc_name, False
255
+ )
256
+
287
257
  self.send_to_tokenizer = get_zmq_socket(
288
258
  context, zmq.PUSH, port_args.tokenizer_ipc_name, False
289
259
  )
290
-
291
260
  if server_args.skip_tokenizer_init:
292
261
  # Directly send to the TokenizerManager
293
262
  self.send_to_detokenizer = get_zmq_socket(
@@ -299,9 +268,6 @@ class Scheduler(
299
268
  context, zmq.PUSH, port_args.detokenizer_ipc_name, False
300
269
  )
301
270
 
302
- self.recv_from_rpc = get_zmq_socket(
303
- context, zmq.DEALER, port_args.rpc_ipc_name, False
304
- )
305
271
  if self.server_args.sleep_on_idle:
306
272
  self.idle_sleeper = IdleSleeper(
307
273
  [
@@ -347,6 +313,7 @@ class Scheduler(
347
313
  server_args=server_args,
348
314
  gpu_id=gpu_id,
349
315
  tp_rank=tp_rank,
316
+ moe_ep_rank=moe_ep_rank,
350
317
  pp_rank=pp_rank,
351
318
  dp_rank=dp_rank,
352
319
  nccl_port=port_args.nccl_port,
@@ -359,6 +326,7 @@ class Scheduler(
359
326
  self.draft_worker = EAGLEWorker(
360
327
  gpu_id=gpu_id,
361
328
  tp_rank=tp_rank,
329
+ moe_ep_rank=moe_ep_rank,
362
330
  server_args=server_args,
363
331
  nccl_port=port_args.nccl_port,
364
332
  target_worker=self.tp_worker,
@@ -398,7 +366,7 @@ class Scheduler(
398
366
  global_server_args_dict.update(worker_global_server_args_dict)
399
367
  set_random_seed(self.random_seed)
400
368
 
401
- # Hybrid
369
+ # Hybrid memory pool
402
370
  self.is_hybrid = self.tp_worker.is_hybrid
403
371
  if self.is_hybrid:
404
372
  self.sliding_window_size = self.tp_worker.sliding_window_size
@@ -515,6 +483,15 @@ class Scheduler(
515
483
  self.init_metrics(tp_rank, pp_rank, dp_rank)
516
484
  self.init_kv_events(server_args.kv_events_config)
517
485
 
486
+ # Init disaggregation
487
+ self.disaggregation_mode = DisaggregationMode(
488
+ self.server_args.disaggregation_mode
489
+ )
490
+ self.init_disaggregation()
491
+
492
+ if get_bool_env_var("SGLANG_GC_LOG"):
493
+ configure_gc_logger()
494
+
518
495
  # Init request dispatcher
519
496
  self._request_dispatcher = TypeBasedDispatcher(
520
497
  [
@@ -545,22 +522,6 @@ class Scheduler(
545
522
  ]
546
523
  )
547
524
 
548
- # Init disaggregation
549
- self.disaggregation_mode = DisaggregationMode(
550
- self.server_args.disaggregation_mode
551
- )
552
- self.init_disaggregation()
553
-
554
- if get_bool_env_var("SGLANG_GC_LOG"):
555
- configure_gc_logger()
556
-
557
- def current_scheduler_metrics_enabled(self):
558
- return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
559
-
560
- def maybe_sleep_on_idle(self):
561
- if self.idle_sleeper is not None:
562
- self.idle_sleeper.maybe_sleep()
563
-
564
525
  def init_tokenizer(self):
565
526
  server_args = self.server_args
566
527
 
@@ -668,50 +629,6 @@ class Scheduler(
668
629
  embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
669
630
  init_embedding_cache(embedding_cache_size * 1024 * 1024)
670
631
 
671
- def init_profier(self):
672
- self.torch_profiler = None
673
- self.torch_profiler_output_dir: Optional[str] = None
674
- self.profiler_activities: Optional[List[str]] = None
675
- self.profile_id: Optional[str] = None
676
- self.profiler_start_forward_ct: Optional[int] = None
677
- self.profiler_target_forward_ct: Optional[int] = None
678
- self.profiler_target_prefill_ct: Optional[int] = None
679
- self.profiler_target_decode_ct: Optional[int] = None
680
- self.profiler_prefill_ct: Optional[int] = None
681
- self.profiler_decode_ct: Optional[int] = None
682
- self.profile_by_stage: bool = False
683
- self.profile_steps: Optional[int] = None
684
- self.profile_in_progress: bool = False
685
- self.rpd_profiler = None
686
-
687
- def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
688
- self.last_gen_throughput: float = 0.0
689
- self.last_input_throughput: float = 0.0
690
- self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
691
- self.spec_num_total_accepted_tokens = 0
692
- self.spec_num_total_forward_ct = 0
693
- self.cum_spec_accept_length = 0
694
- self.cum_spec_accept_count = 0
695
- self.total_retracted_reqs = 0
696
- self.stats = SchedulerStats()
697
- if self.enable_metrics:
698
- engine_type = "unified"
699
- labels = {
700
- "model_name": self.server_args.served_model_name,
701
- "engine_type": engine_type,
702
- "tp_rank": tp_rank,
703
- "pp_rank": pp_rank,
704
- }
705
- if dp_rank is not None:
706
- labels["dp_rank"] = dp_rank
707
- self.metrics_collector = SchedulerMetricsCollector(labels=labels)
708
-
709
- def init_kv_events(self, kv_events_config: Optional[str]):
710
- if self.enable_kv_cache_events:
711
- self.kv_event_publisher = EventPublisherFactory.create(
712
- kv_events_config, self.attn_dp_rank
713
- )
714
-
715
632
  def init_disaggregation(self):
716
633
  self.transfer_backend = TransferBackend(
717
634
  self.server_args.disaggregation_transfer_backend
@@ -820,10 +737,7 @@ class Scheduler(
820
737
  self.process_batch_result(batch, result)
821
738
  else:
822
739
  # When the server is idle, do self-check and re-init some states
823
- self.check_memory()
824
- self.check_tree_cache()
825
- self.new_token_ratio = self.init_new_token_ratio
826
- self.maybe_sleep_on_idle()
740
+ self.self_check_during_idle()
827
741
 
828
742
  self.last_batch = batch
829
743
 
@@ -866,10 +780,7 @@ class Scheduler(
866
780
  )
867
781
  elif batch is None:
868
782
  # When the server is idle, do self-check and re-init some states
869
- self.check_memory()
870
- self.check_tree_cache()
871
- self.new_token_ratio = self.init_new_token_ratio
872
- self.maybe_sleep_on_idle()
783
+ self.self_check_during_idle()
873
784
 
874
785
  self.last_batch = batch
875
786
 
@@ -1003,10 +914,8 @@ class Scheduler(
1003
914
 
1004
915
  # When the server is idle, self-check and re-init some states
1005
916
  if server_is_idle:
1006
- self.check_memory()
1007
- self.check_tree_cache()
1008
- self.new_token_ratio = self.init_new_token_ratio
1009
- self.maybe_sleep_on_idle()
917
+ # When the server is idle, do self-check and re-init some states
918
+ self.self_check_during_idle()
1010
919
 
1011
920
  def recv_requests(self) -> List[Req]:
1012
921
  """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
@@ -1281,23 +1190,28 @@ class Scheduler(
1281
1190
  def _add_request_to_queue(self, req: Req):
1282
1191
  req.queue_time_start = time.perf_counter()
1283
1192
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
1193
+ self._prefetch_kvcache(req)
1284
1194
  self.disagg_prefill_bootstrap_queue.add(
1285
1195
  req, self.model_config.num_key_value_heads
1286
1196
  )
1287
1197
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
1288
1198
  self.disagg_decode_prealloc_queue.add(req)
1289
1199
  else:
1290
- if self.enable_hicache_storage:
1291
- req.init_next_round_input(self.tree_cache)
1292
- last_hash = req.last_host_node.get_last_hash_value()
1293
- matched_len = len(req.prefix_indices) + req.host_hit_length
1294
- if (matched_len > 0 and last_hash is not None) or matched_len == 0:
1295
- new_input_tokens = req.fill_ids[matched_len:]
1296
- self.tree_cache.prefetch_from_storage(
1297
- req.rid, req.last_host_node, new_input_tokens, last_hash
1298
- )
1200
+ self._prefetch_kvcache(req)
1299
1201
  self.waiting_queue.append(req)
1300
1202
 
1203
+ def _prefetch_kvcache(self, req: Req):
1204
+ if self.enable_hicache_storage:
1205
+ req.init_next_round_input(self.tree_cache)
1206
+ last_hash = req.last_host_node.get_last_hash_value()
1207
+ matched_len = len(req.prefix_indices) + req.host_hit_length
1208
+ # todo, free-form fetching, calculating hash keys on the fly
1209
+ if (matched_len > 0 and last_hash is not None) or matched_len == 0:
1210
+ new_input_tokens = req.fill_ids[matched_len:]
1211
+ self.tree_cache.prefetch_from_storage(
1212
+ req.rid, req.last_host_node, new_input_tokens, last_hash
1213
+ )
1214
+
1301
1215
  def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
1302
1216
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
1303
1217
  self.disagg_prefill_bootstrap_queue.extend(
@@ -1355,170 +1269,11 @@ class Scheduler(
1355
1269
  req.logprob_start_len = len(req.origin_input_ids) - 1
1356
1270
  self._add_request_to_queue(req)
1357
1271
 
1358
- def _emit_kv_metrics(self):
1359
- kv_metrics = KvMetrics()
1360
- kv_metrics.request_active_slots = self.stats.num_running_reqs
1361
- kv_metrics.request_total_slots = self.max_running_requests
1362
- kv_metrics.kv_active_blocks = int(
1363
- self.stats.token_usage * self.max_total_num_tokens
1364
- )
1365
- kv_metrics.kv_total_blocks = self.max_total_num_tokens
1366
- kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
1367
- kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
1368
- kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
1369
- kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
1370
-
1371
- if not self.send_metrics_from_scheduler.closed:
1372
- self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
1373
-
1374
- def log_prefill_stats(
1375
- self,
1376
- adder: PrefillAdder,
1377
- can_run_list: List[Req],
1378
- running_bs: int,
1379
- ):
1380
- gap_latency = time.perf_counter() - self.last_prefill_stats_tic
1381
- self.last_prefill_stats_tic = time.perf_counter()
1382
- self.last_input_throughput = self.last_prefill_tokens / gap_latency
1383
- self.last_prefill_tokens = adder.log_input_tokens
1384
-
1385
- if self.is_hybrid:
1386
- (
1387
- full_num_used,
1388
- swa_num_used,
1389
- full_token_usage,
1390
- swa_token_usage,
1391
- _,
1392
- _,
1393
- _,
1394
- _,
1395
- ) = self._get_swa_token_info()
1396
- num_used = max(full_num_used, swa_num_used)
1397
- token_usage = max(full_token_usage, swa_token_usage)
1398
- token_msg = (
1399
- f"full token usage: {full_token_usage:.2f}, "
1400
- f"swa token usage: {swa_token_usage:.2f}, "
1401
- )
1402
- else:
1403
- num_used, token_usage, _, _ = self._get_token_info()
1404
- token_msg = f"token usage: {token_usage:.2f}, "
1405
-
1406
- num_new_seq = len(can_run_list)
1407
- f = (
1408
- f"Prefill batch. "
1409
- f"#new-seq: {num_new_seq}, "
1410
- f"#new-token: {adder.log_input_tokens}, "
1411
- f"#cached-token: {adder.log_hit_tokens}, "
1412
- f"{token_msg}"
1413
- )
1414
-
1415
- if self.disaggregation_mode == DisaggregationMode.PREFILL:
1416
- f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
1417
- f += f"#queue-req: {len(self.waiting_queue)}, "
1418
- f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
1419
- f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
1420
- else:
1421
- f += f"#running-req: {running_bs}, "
1422
- f += f"#queue-req: {len(self.waiting_queue)}, "
1423
-
1424
- logger.info(f)
1425
-
1426
- if self.enable_metrics:
1427
- total_tokens = adder.log_input_tokens + adder.log_hit_tokens
1428
-
1429
- cache_hit_rate = (
1430
- adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
1431
- )
1432
- self.stats.num_running_reqs = running_bs
1433
- self.stats.num_used_tokens = num_used
1434
- self.stats.token_usage = round(token_usage, 2)
1435
- self.stats.num_queue_reqs = len(self.waiting_queue)
1436
- self.stats.cache_hit_rate = cache_hit_rate
1437
-
1438
- total_queue_latency = 0
1439
- for req in can_run_list:
1440
- total_queue_latency += req.queue_time_end - req.queue_time_start
1441
- self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
1442
-
1443
- self.metrics_collector.log_stats(self.stats)
1444
- self._emit_kv_metrics()
1445
- self._publish_kv_events()
1446
-
1447
- def log_decode_stats(
1448
- self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
1449
- ):
1450
- batch = running_batch or self.running_batch
1451
-
1452
- gap_latency = time.perf_counter() - self.last_decode_stats_tic
1453
- self.last_decode_stats_tic = time.perf_counter()
1454
- self.last_gen_throughput = self.num_generated_tokens / gap_latency
1455
- self.num_generated_tokens = 0
1456
- num_running_reqs = len(batch.reqs)
1457
- if self.is_hybrid:
1458
- (
1459
- full_num_used,
1460
- swa_num_used,
1461
- full_token_usage,
1462
- swa_token_usage,
1463
- _,
1464
- _,
1465
- _,
1466
- _,
1467
- ) = self._get_swa_token_info()
1468
- num_used = max(full_num_used, swa_num_used)
1469
- token_usage = max(full_token_usage, swa_token_usage)
1470
- token_msg = (
1471
- f"#full token: {full_num_used}, "
1472
- f"full token usage: {full_token_usage:.2f}, "
1473
- f"#swa token: {swa_num_used}, "
1474
- f"swa token usage: {swa_token_usage:.2f}, "
1475
- )
1476
- else:
1477
- num_used, token_usage, _, _ = self._get_token_info()
1478
- token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
1479
-
1480
- if RECORD_STEP_TIME:
1481
- self.step_time_dict[num_running_reqs].append(
1482
- gap_latency / self.server_args.decode_log_interval
1483
- )
1484
-
1485
- msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
1486
-
1487
- if self.spec_algorithm.is_none():
1488
- spec_accept_length = 0
1489
- else:
1490
- spec_accept_length = (
1491
- self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
1492
- )
1493
- self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
1494
- self.cum_spec_accept_count += self.spec_num_total_forward_ct
1495
- self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
1496
- msg += f"accept len: {spec_accept_length:.2f}, "
1497
-
1498
- if self.disaggregation_mode == DisaggregationMode.DECODE:
1499
- msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
1500
- msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
1501
-
1502
- msg += (
1503
- f"cuda graph: {can_run_cuda_graph}, "
1504
- f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
1505
- f"#queue-req: {len(self.waiting_queue)}, "
1506
- )
1507
-
1508
- logger.info(msg)
1509
- if self.enable_metrics:
1510
- self.stats.num_running_reqs = num_running_reqs
1511
- self.stats.num_used_tokens = num_used
1512
- self.stats.token_usage = round(token_usage, 2)
1513
- self.stats.cache_hit_rate = 0.0
1514
- self.stats.gen_throughput = self.last_gen_throughput
1515
- self.stats.num_queue_reqs = len(self.waiting_queue)
1516
- self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
1517
- self.stats.spec_accept_length = spec_accept_length
1518
- self.stats.total_retracted_reqs = self.total_retracted_reqs
1519
- self.metrics_collector.log_stats(self.stats)
1520
- self._emit_kv_metrics()
1521
- self._publish_kv_events()
1272
+ def self_check_during_idle(self):
1273
+ self.check_memory()
1274
+ self.check_tree_cache()
1275
+ self.new_token_ratio = self.init_new_token_ratio
1276
+ self.maybe_sleep_on_idle()
1522
1277
 
1523
1278
  def check_memory(self):
1524
1279
  if self.is_hybrid:
@@ -2422,22 +2177,6 @@ class Scheduler(
2422
2177
  barrier()
2423
2178
  return RpcReqOutput(success, "" if not exec else str(exec))
2424
2179
 
2425
- def save_remote_model(self, params):
2426
- url = params["url"]
2427
-
2428
- worker = self.tp_worker.worker
2429
-
2430
- worker.model_runner.save_remote_model(url)
2431
-
2432
- def save_sharded_model(self, params):
2433
- worker = self.tp_worker.worker
2434
-
2435
- worker.model_runner.save_sharded_model(
2436
- path=params["path"],
2437
- pattern=params["pattern"],
2438
- max_size=params["max_size"],
2439
- )
2440
-
2441
2180
  def abort_request(self, recv_req: AbortReq):
2442
2181
  # Delete requests in the waiting queue
2443
2182
  to_del = []
@@ -2515,16 +2254,6 @@ class Scheduler(
2515
2254
  def _pause_engine(self) -> Tuple[List[Req], int]:
2516
2255
  raise NotImplementedError()
2517
2256
 
2518
- def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
2519
- """In-place update of the weights from disk."""
2520
- success, message = self.tp_worker.update_weights_from_disk(recv_req)
2521
- if success:
2522
- flush_cache_success = self.flush_cache()
2523
- assert flush_cache_success, "Cache flush failed after updating weights"
2524
- else:
2525
- logger.error(message)
2526
- return UpdateWeightFromDiskReqOutput(success, message, 0)
2527
-
2528
2257
  def load_lora_adapter(
2529
2258
  self, recv_req: LoadLoRAAdapterReqInput
2530
2259
  ) -> LoadLoRAAdapterReqOutput:
@@ -2541,81 +2270,6 @@ class Scheduler(
2541
2270
  result = self.tp_worker.unload_lora_adapter(recv_req)
2542
2271
  return result
2543
2272
 
2544
- def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
2545
- """Initialize the online model parameter update group."""
2546
- success, message = self.tp_worker.init_weights_update_group(recv_req)
2547
- return InitWeightsUpdateGroupReqOutput(success, message)
2548
-
2549
- def update_weights_from_distributed(
2550
- self,
2551
- recv_req: UpdateWeightsFromDistributedReqInput,
2552
- ) -> Tuple[bool, str]:
2553
- """Update the online model parameter."""
2554
- success, message = self.tp_worker.update_weights_from_distributed(recv_req)
2555
- if success:
2556
- if recv_req.flush_cache:
2557
- flush_cache_success = self.flush_cache()
2558
- assert flush_cache_success, "Cache flush failed after updating weights"
2559
- else:
2560
- logger.error(message)
2561
- return UpdateWeightsFromDistributedReqOutput(success, message)
2562
-
2563
- def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
2564
- """Update the online model parameter from tensors."""
2565
- success, message = self.tp_worker.update_weights_from_tensor(recv_req)
2566
- # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
2567
- if success:
2568
- if recv_req.flush_cache:
2569
- flush_cache_success = self.flush_cache()
2570
- assert flush_cache_success, "Cache flush failed after updating weights"
2571
- else:
2572
- logger.error(message)
2573
- barrier(group=self.tp_cpu_group)
2574
- return UpdateWeightsFromTensorReqOutput(success, message)
2575
-
2576
- def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
2577
- parameter = self.tp_worker.get_weights_by_name(recv_req)
2578
- return GetWeightsByNameReqOutput(parameter)
2579
-
2580
- def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
2581
- tags = recv_req.tags
2582
-
2583
- if tags is None or len(tags) == 0:
2584
- tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
2585
-
2586
- if GPU_MEMORY_TYPE_KV_CACHE in tags:
2587
- self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
2588
- self.flush_cache()
2589
-
2590
- if GPU_MEMORY_TYPE_WEIGHTS in tags:
2591
- self.stashed_model_static_state = _export_static_state(
2592
- self.tp_worker.worker.model_runner.model
2593
- )
2594
- torch.distributed.barrier(self.tp_cpu_group)
2595
- self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
2596
-
2597
- return ReleaseMemoryOccupationReqOutput()
2598
-
2599
- def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
2600
- tags = recv_req.tags
2601
-
2602
- if tags is None or len(tags) == 0:
2603
- tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
2604
-
2605
- if GPU_MEMORY_TYPE_WEIGHTS in tags:
2606
- self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
2607
- torch.distributed.barrier(self.tp_cpu_group)
2608
- _import_static_state(
2609
- self.tp_worker.worker.model_runner.model,
2610
- self.stashed_model_static_state,
2611
- )
2612
- del self.stashed_model_static_state
2613
-
2614
- if GPU_MEMORY_TYPE_KV_CACHE in tags:
2615
- self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
2616
-
2617
- return ResumeMemoryOccupationReqOutput()
2618
-
2619
2273
  def slow_down(self, recv_req: SlowDownReqInput):
2620
2274
  t = recv_req.forward_sleep_time
2621
2275
  if t is not None and t <= 0:
@@ -2623,254 +2277,6 @@ class Scheduler(
2623
2277
  self.forward_sleep_time = t
2624
2278
  return SlowDownReqOutput()
2625
2279
 
2626
- def profile(self, recv_req: ProfileReq):
2627
- if recv_req.type == ProfileReqType.START_PROFILE:
2628
- if recv_req.profile_by_stage or recv_req.start_step:
2629
- return self.init_profile(
2630
- recv_req.output_dir,
2631
- recv_req.start_step,
2632
- recv_req.num_steps,
2633
- recv_req.activities,
2634
- recv_req.with_stack,
2635
- recv_req.record_shapes,
2636
- recv_req.profile_by_stage,
2637
- recv_req.profile_id,
2638
- )
2639
- else:
2640
- self.init_profile(
2641
- recv_req.output_dir,
2642
- recv_req.start_step,
2643
- recv_req.num_steps,
2644
- recv_req.activities,
2645
- recv_req.with_stack,
2646
- recv_req.record_shapes,
2647
- recv_req.profile_by_stage,
2648
- recv_req.profile_id,
2649
- )
2650
- return self.start_profile(True)
2651
- else:
2652
- return self.stop_profile()
2653
-
2654
- def init_profile(
2655
- self,
2656
- output_dir: Optional[str],
2657
- start_step: Optional[int],
2658
- num_steps: Optional[int],
2659
- activities: Optional[List[str]],
2660
- with_stack: Optional[bool],
2661
- record_shapes: Optional[bool],
2662
- profile_by_stage: bool,
2663
- profile_id: str,
2664
- ) -> ProfileReqOutput:
2665
- if self.profile_in_progress:
2666
- return ProfileReqOutput(
2667
- success=False,
2668
- message="Profiling is already in progress. Call /stop_profile first.",
2669
- )
2670
-
2671
- self.profile_by_stage = profile_by_stage
2672
-
2673
- if output_dir is None:
2674
- output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
2675
- if activities is None:
2676
- activities = ["CPU", "GPU"]
2677
-
2678
- self.torch_profiler_output_dir = output_dir
2679
- self.torch_profiler_with_stack = with_stack
2680
- self.torch_profiler_record_shapes = record_shapes
2681
- self.profiler_activities = activities
2682
- self.profile_id = profile_id
2683
-
2684
- if start_step:
2685
- self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
2686
-
2687
- if num_steps:
2688
- self.profile_steps = num_steps
2689
- if self.profile_by_stage:
2690
- self.profiler_target_prefill_ct = num_steps
2691
- self.profiler_target_decode_ct = num_steps
2692
- self.profiler_prefill_ct = 0
2693
- self.profiler_decode_ct = 0
2694
- elif start_step:
2695
- self.profiler_target_forward_ct = (
2696
- self.profiler_start_forward_ct + num_steps
2697
- )
2698
- else:
2699
- self.profiler_target_forward_ct = self.forward_ct + num_steps
2700
- # The caller will be notified when reaching profiler_target_forward_ct
2701
- else:
2702
- self.profiler_target_forward_ct = None
2703
-
2704
- return ProfileReqOutput(success=True, message="Succeeded")
2705
-
2706
- def start_profile(
2707
- self, stage: Optional[ForwardMode] = None
2708
- ) -> ProfileReqOutput | None:
2709
- stage_str = f" for {stage.__str__()}" if stage else ""
2710
- logger.info(
2711
- f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
2712
- )
2713
-
2714
- activities = self.profiler_activities
2715
- with_stack = self.torch_profiler_with_stack
2716
- record_shapes = self.torch_profiler_record_shapes
2717
-
2718
- activity_map = {
2719
- "CPU": torch.profiler.ProfilerActivity.CPU,
2720
- "GPU": torch.profiler.ProfilerActivity.CUDA,
2721
- }
2722
- torchprof_activities = [
2723
- activity_map[a] for a in activities if a in activity_map
2724
- ]
2725
-
2726
- if "RPD" in activities:
2727
- from rpdTracerControl import rpdTracerControl
2728
-
2729
- rpdTracerControl.skipCreate()
2730
-
2731
- self.rpd_profile_path = os.path.join(
2732
- self.torch_profiler_output_dir,
2733
- "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
2734
- )
2735
-
2736
- if self.tp_rank == 0:
2737
- import sqlite3
2738
-
2739
- from rocpd.schema import RocpdSchema
2740
-
2741
- if os.path.exists("trace.rpd"):
2742
- os.unlink("trace.rpd")
2743
- schema = RocpdSchema()
2744
- connection = sqlite3.connect("trace.rpd")
2745
- schema.writeSchema(connection)
2746
- connection.commit()
2747
- del connection
2748
- torch.distributed.barrier(self.tp_cpu_group)
2749
-
2750
- self.rpd_profiler = rpdTracerControl()
2751
- self.rpd_profiler.setPythonTrace(True)
2752
- self.rpd_profiler.start()
2753
- self.rpd_profiler.rangePush("", "rpd profile range", "")
2754
- self.profile_in_progress = True
2755
- elif torchprof_activities:
2756
- self.torch_profiler = torch.profiler.profile(
2757
- activities=torchprof_activities,
2758
- with_stack=with_stack if with_stack is not None else True,
2759
- record_shapes=record_shapes if record_shapes is not None else False,
2760
- )
2761
- self.torch_profiler.start()
2762
- self.profile_in_progress = True
2763
-
2764
- if "MEM" in activities:
2765
- torch.cuda.memory._record_memory_history(max_entries=100000)
2766
- self.profile_in_progress = True
2767
-
2768
- if "CUDA_PROFILER" in activities:
2769
- torch.cuda.cudart().cudaProfilerStart()
2770
- self.profile_in_progress = True
2771
-
2772
- return ProfileReqOutput(success=True, message="Succeeded")
2773
-
2774
- def stop_profile(
2775
- self, stage: Optional[ForwardMode] = None
2776
- ) -> ProfileReqOutput | None:
2777
- if not self.profile_in_progress:
2778
- return ProfileReqOutput(
2779
- success=False,
2780
- message="Profiling is not in progress. Call /start_profile first.",
2781
- )
2782
-
2783
- if not Path(self.torch_profiler_output_dir).exists():
2784
- Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
2785
-
2786
- stage_suffix = f"-{stage.__str__()}" if stage else ""
2787
- logger.info("Stop profiling" + stage_suffix + "...")
2788
- if self.torch_profiler is not None:
2789
- self.torch_profiler.stop()
2790
- self.torch_profiler.export_chrome_trace(
2791
- os.path.join(
2792
- self.torch_profiler_output_dir,
2793
- self.profile_id
2794
- + f"-TP-{self.tp_rank}"
2795
- + stage_suffix
2796
- + ".trace.json.gz",
2797
- )
2798
- )
2799
- torch.distributed.barrier(self.tp_cpu_group)
2800
-
2801
- if self.rpd_profiler is not None:
2802
- self.rpd_profiler.rangePop()
2803
- self.rpd_profiler.stop()
2804
- self.rpd_profiler.flush()
2805
-
2806
- torch.distributed.barrier(self.tp_cpu_group)
2807
- if self.tp_rank == 0:
2808
- from sglang.srt.utils import rpd_to_chrome_trace
2809
-
2810
- rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
2811
- self.rpd_profiler = None
2812
- self.rpd_profiler_path = None
2813
-
2814
- if self.profiler_activities is not None and "MEM" in self.profiler_activities:
2815
- memory_profile_path = os.path.join(
2816
- self.torch_profiler_output_dir,
2817
- str(time.time())
2818
- + f"-TP-{self.tp_rank}-memory"
2819
- + stage_suffix
2820
- + ".pickle",
2821
- )
2822
- torch.cuda.memory._dump_snapshot(memory_profile_path)
2823
- torch.cuda.memory._record_memory_history(enabled=None)
2824
-
2825
- if "CUDA_PROFILER" in self.profiler_activities:
2826
- torch.cuda.cudart().cudaProfilerStop()
2827
-
2828
- logger.info(
2829
- "Profiling done. Traces are saved to: %s",
2830
- self.torch_profiler_output_dir,
2831
- )
2832
- self.torch_profiler = None
2833
- self.profile_in_progress = False
2834
- self.profiler_start_forward_ct = None
2835
-
2836
- return ProfileReqOutput(success=True, message="Succeeded.")
2837
-
2838
- def _profile_batch_predicate(self, batch):
2839
- if self.profile_by_stage:
2840
- if batch.forward_mode.is_prefill():
2841
- if self.profiler_prefill_ct == 0:
2842
- self.start_profile(batch.forward_mode)
2843
- self.profiler_prefill_ct += 1
2844
- if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
2845
- if self.profile_in_progress:
2846
- self.stop_profile(stage=ForwardMode.EXTEND)
2847
- elif batch.forward_mode.is_decode():
2848
- if self.profiler_decode_ct == 0:
2849
- if self.profile_in_progress:
2850
- # force trace flush
2851
- self.stop_profile(ForwardMode.EXTEND)
2852
- self.start_profile(batch.forward_mode)
2853
- self.profiler_decode_ct += 1
2854
- if self.profiler_decode_ct > self.profiler_target_decode_ct:
2855
- if self.profile_in_progress:
2856
- self.stop_profile(stage=ForwardMode.DECODE)
2857
- elif batch.forward_mode.is_idle():
2858
- pass
2859
- else:
2860
- raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
2861
- else:
2862
- # Check profiler
2863
- if (
2864
- self.profiler_target_forward_ct
2865
- and self.profiler_target_forward_ct <= self.forward_ct
2866
- ):
2867
- self.stop_profile()
2868
- if (
2869
- self.profiler_start_forward_ct
2870
- and self.profiler_start_forward_ct == self.forward_ct
2871
- ):
2872
- self.start_profile()
2873
-
2874
2280
  def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
2875
2281
  if recv_req == ExpertDistributionReq.START_RECORD:
2876
2282
  get_global_expert_distribution_recorder().start_record()
@@ -2879,7 +2285,7 @@ class Scheduler(
2879
2285
  elif recv_req == ExpertDistributionReq.DUMP_RECORD:
2880
2286
  get_global_expert_distribution_recorder().dump_record()
2881
2287
  else:
2882
- raise ValueError("Unrecognized ExpertDistributionReq value")
2288
+ raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
2883
2289
  return ExpertDistributionReqOutput()
2884
2290
 
2885
2291
  def open_session(self, recv_req: OpenSessionReqInput):
@@ -2915,34 +2321,41 @@ class Scheduler(
2915
2321
  prefix += f" PP{self.pp_rank}"
2916
2322
  return prefix
2917
2323
 
2918
- def _publish_kv_events(self):
2919
- if self.enable_kv_cache_events:
2920
- events = self.tree_cache.take_events()
2921
- if events:
2922
- batch = KVEventBatch(ts=time.time(), events=events)
2923
- self.kv_event_publisher.publish(batch)
2324
+ def current_scheduler_metrics_enabled(self):
2325
+ return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
2924
2326
 
2327
+ def maybe_sleep_on_idle(self):
2328
+ if self.idle_sleeper is not None:
2329
+ self.idle_sleeper.maybe_sleep()
2925
2330
 
2926
- def is_health_check_generate_req(recv_req):
2927
- return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
2928
2331
 
2332
+ class IdleSleeper:
2333
+ """
2334
+ In setups which have long inactivity periods it is desirable to reduce
2335
+ system power consumption when sglang does nothing. This would lead not only
2336
+ to power savings, but also to more CPU thermal headroom when a request
2337
+ eventually comes. This is important in cases when multiple GPUs are connected
2338
+ as each GPU would otherwise pin one thread at 100% CPU usage.
2929
2339
 
2930
- def is_work_request(recv_req):
2931
- return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
2340
+ The simplest solution is to use zmq.Poller on all sockets that may receive
2341
+ data that needs handling immediately.
2342
+ """
2932
2343
 
2344
+ def __init__(self, sockets):
2345
+ self.poller = zmq.Poller()
2346
+ for s in sockets:
2347
+ self.poller.register(s, zmq.POLLIN)
2933
2348
 
2934
- def _export_static_state(model):
2935
- return dict(
2936
- buffers=[
2937
- (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
2938
- ]
2939
- )
2349
+ def maybe_sleep(self):
2350
+ self.poller.poll(1000)
2940
2351
 
2941
2352
 
2942
- def _import_static_state(model, static_params):
2943
- self_named_buffers = dict(model.named_buffers())
2944
- for name, tensor in static_params["buffers"]:
2945
- self_named_buffers[name][...] = tensor
2353
+ def is_health_check_generate_req(recv_req):
2354
+ return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
2355
+
2356
+
2357
+ def is_work_request(recv_req):
2358
+ return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
2946
2359
 
2947
2360
 
2948
2361
  def run_scheduler_process(
@@ -2950,6 +2363,7 @@ def run_scheduler_process(
2950
2363
  port_args: PortArgs,
2951
2364
  gpu_id: int,
2952
2365
  tp_rank: int,
2366
+ moe_ep_rank: int,
2953
2367
  pp_rank: int,
2954
2368
  dp_rank: Optional[int],
2955
2369
  pipe_writer,
@@ -2960,6 +2374,8 @@ def run_scheduler_process(
2960
2374
  prefix += f" DP{dp_rank}"
2961
2375
  if server_args.tp_size > 1:
2962
2376
  prefix += f" TP{tp_rank}"
2377
+ if server_args.ep_size > 1:
2378
+ prefix += f" EP{moe_ep_rank}"
2963
2379
  if server_args.pp_size > 1:
2964
2380
  prefix += f" PP{pp_rank}"
2965
2381
 
@@ -2983,7 +2399,9 @@ def run_scheduler_process(
2983
2399
 
2984
2400
  # Create a scheduler and run the event loop
2985
2401
  try:
2986
- scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
2402
+ scheduler = Scheduler(
2403
+ server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank
2404
+ )
2987
2405
  pipe_writer.send(
2988
2406
  {
2989
2407
  "status": "ready",