sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/bench_one_batch.py +3 -0
  3. sglang/srt/configs/__init__.py +8 -0
  4. sglang/srt/configs/model_config.py +4 -0
  5. sglang/srt/configs/step3_vl.py +172 -0
  6. sglang/srt/conversation.py +23 -0
  7. sglang/srt/disaggregation/decode.py +2 -8
  8. sglang/srt/disaggregation/launch_lb.py +5 -20
  9. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  10. sglang/srt/disaggregation/prefill.py +2 -6
  11. sglang/srt/distributed/parallel_state.py +86 -1
  12. sglang/srt/entrypoints/engine.py +14 -18
  13. sglang/srt/entrypoints/http_server.py +10 -2
  14. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  15. sglang/srt/eplb/expert_distribution.py +5 -0
  16. sglang/srt/eplb/expert_location.py +17 -6
  17. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  18. sglang/srt/eplb/expert_location_updater.py +2 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/step3_detector.py +436 -0
  21. sglang/srt/hf_transformers_utils.py +2 -0
  22. sglang/srt/jinja_template_utils.py +4 -1
  23. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  24. sglang/srt/layers/attention/utils.py +6 -1
  25. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +39 -674
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
  29. sglang/srt/layers/quantization/fp8.py +52 -18
  30. sglang/srt/layers/quantization/unquant.py +0 -8
  31. sglang/srt/layers/quantization/w4afp8.py +1 -0
  32. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  33. sglang/srt/managers/cache_controller.py +165 -67
  34. sglang/srt/managers/data_parallel_controller.py +2 -0
  35. sglang/srt/managers/io_struct.py +0 -2
  36. sglang/srt/managers/scheduler.py +90 -671
  37. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  38. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  39. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  40. sglang/srt/managers/template_manager.py +62 -19
  41. sglang/srt/managers/tokenizer_manager.py +123 -74
  42. sglang/srt/managers/tp_worker.py +4 -0
  43. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  44. sglang/srt/mem_cache/hicache_storage.py +60 -17
  45. sglang/srt/mem_cache/hiradix_cache.py +36 -8
  46. sglang/srt/mem_cache/memory_pool.py +15 -118
  47. sglang/srt/mem_cache/memory_pool_host.py +418 -29
  48. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  49. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  50. sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
  51. sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
  52. sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
  53. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
  54. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  55. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  56. sglang/srt/model_executor/cuda_graph_runner.py +25 -1
  57. sglang/srt/model_executor/model_runner.py +13 -1
  58. sglang/srt/model_loader/weight_utils.py +2 -0
  59. sglang/srt/models/arcee.py +532 -0
  60. sglang/srt/models/deepseek_v2.py +7 -6
  61. sglang/srt/models/glm4_moe.py +6 -4
  62. sglang/srt/models/granitemoe.py +3 -0
  63. sglang/srt/models/grok.py +3 -0
  64. sglang/srt/models/hunyuan.py +1 -0
  65. sglang/srt/models/llama4.py +3 -0
  66. sglang/srt/models/mixtral.py +3 -0
  67. sglang/srt/models/olmoe.py +3 -0
  68. sglang/srt/models/phimoe.py +1 -0
  69. sglang/srt/models/step3_vl.py +991 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/reasoning_parser.py +2 -1
  73. sglang/srt/server_args.py +49 -18
  74. sglang/srt/speculative/eagle_worker.py +2 -0
  75. sglang/srt/utils.py +1 -0
  76. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  77. sglang/utils.py +0 -11
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
  80. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
  81. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
  82. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,6 @@
  # ==============================================================================
  """A scheduler that manages a tensor parallel GPU worker."""

- import datetime
  import faulthandler
  import logging
  import os
@@ -21,11 +20,10 @@ import signal
  import sys
  import threading
  import time
- from collections import defaultdict, deque
+ from collections import deque
  from concurrent import futures
  from dataclasses import dataclass
  from http import HTTPStatus
- from pathlib import Path
  from types import SimpleNamespace
  from typing import Dict, List, Optional, Tuple, Union

@@ -37,7 +35,6 @@ from torch.distributed import barrier

  from sglang.global_config import global_config
  from sglang.srt.configs.model_config import ModelConfig
- from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
  from sglang.srt.constrained.base_grammar_backend import (
  INVALID_GRAMMAR_OBJ,
  create_grammar_backend,
@@ -47,7 +44,6 @@ from sglang.srt.disaggregation.decode import (
  DecodeTransferQueue,
  SchedulerDisaggregationDecodeMixin,
  )
- from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
  from sglang.srt.disaggregation.prefill import (
  PrefillBootstrapQueue,
  SchedulerDisaggregationPrefillMixin,
@@ -78,21 +74,15 @@ from sglang.srt.managers.io_struct import (
  GetInternalStateReq,
  GetInternalStateReqOutput,
  GetWeightsByNameReqInput,
- GetWeightsByNameReqOutput,
  HealthCheckOutput,
  InitWeightsUpdateGroupReqInput,
- InitWeightsUpdateGroupReqOutput,
  LoadLoRAAdapterReqInput,
  LoadLoRAAdapterReqOutput,
  OpenSessionReqInput,
  OpenSessionReqOutput,
  ProfileReq,
- ProfileReqOutput,
- ProfileReqType,
  ReleaseMemoryOccupationReqInput,
- ReleaseMemoryOccupationReqOutput,
  ResumeMemoryOccupationReqInput,
- ResumeMemoryOccupationReqOutput,
  RpcReqInput,
  RpcReqOutput,
  SetInternalStateReq,
@@ -104,11 +94,8 @@ from sglang.srt.managers.io_struct import (
  UnloadLoRAAdapterReqInput,
  UnloadLoRAAdapterReqOutput,
  UpdateWeightFromDiskReqInput,
- UpdateWeightFromDiskReqOutput,
  UpdateWeightsFromDistributedReqInput,
- UpdateWeightsFromDistributedReqOutput,
  UpdateWeightsFromTensorReqInput,
- UpdateWeightsFromTensorReqOutput,
  )
  from sglang.srt.managers.mm_utils import init_embedding_cache
  from sglang.srt.managers.schedule_batch import (
@@ -124,9 +111,17 @@ from sglang.srt.managers.schedule_policy import (
  SchedulePolicy,
  )
  from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
+ from sglang.srt.managers.scheduler_metrics_mixin import (
+ RECORD_STEP_TIME,
+ SchedulerMetricsMixin,
+ )
  from sglang.srt.managers.scheduler_output_processor_mixin import (
  SchedulerOutputProcessorMixin,
  )
+ from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
+ from sglang.srt.managers.scheduler_update_weights_mixin import (
+ SchedulerUpdateWeightsMixin,
+ )
  from sglang.srt.managers.session_controller import Session
  from sglang.srt.managers.tp_worker import TpModelWorker
  from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
@@ -135,7 +130,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
  from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
- from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
  from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
  from sglang.srt.reasoning_parser import ReasoningParser
  from sglang.srt.server_args import PortArgs, ServerArgs
@@ -168,7 +162,6 @@ logger = logging.getLogger(__name__)

  # Test retract decode for debugging purposes
  TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
- RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
  GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

  _is_cpu = is_cpu()
@@ -191,41 +184,11 @@ class EmbeddingBatchResult:
  bid: int


- class KvMetrics:
- def __init__(self):
- self.request_active_slots = None
- self.request_total_slots = None
- self.kv_active_blocks = None
- self.kv_total_blocks = None
- self.num_requests_waiting = None
- self.gpu_cache_usage_perc = None
- self.gpu_prefix_cache_hit_rate = None
- self.data_parallel_rank = None
-
-
- class IdleSleeper:
- """
- In setups which have long inactivity periods it is desirable to reduce
- system power consumption when sglang does nothing. This would lead not only
- to power savings, but also to more CPU thermal headroom when a request
- eventually comes. This is important in cases when multiple GPUs are connected
- as each GPU would otherwise pin one thread at 100% CPU usage.
-
- The simplest solution is to use zmq.Poller on all sockets that may receive
- data that needs handling immediately.
- """
-
- def __init__(self, sockets):
- self.poller = zmq.Poller()
- for s in sockets:
- self.poller.register(s, zmq.POLLIN)
-
- def maybe_sleep(self):
- self.poller.poll(1000)
-
-
  class Scheduler(
  SchedulerOutputProcessorMixin,
+ SchedulerUpdateWeightsMixin,
+ SchedulerProfilerMixin,
+ SchedulerMetricsMixin,
  SchedulerDisaggregationDecodeMixin,
  SchedulerDisaggregationPrefillMixin,
  ):
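The import and class-base hunks above move the metrics, profiler, and weight-update logic of scheduler.py into the new scheduler_metrics_mixin.py, scheduler_profiler_mixin.py, and scheduler_update_weights_mixin.py modules; Scheduler now inherits those methods instead of defining them inline. A minimal sketch of that composition pattern follows; only the mixin class names mirror the diff, the method bodies are illustrative placeholders, not taken from the package:

# Sketch only: mixin names follow the diff, bodies are hypothetical.
class SchedulerMetricsMixin:
    def init_metrics(self, tp_rank, pp_rank, dp_rank):
        self.stats = {}  # stand-in for the real per-scheduler stats object

class SchedulerProfilerMixin:
    def init_profiler(self):
        self.torch_profiler = None

class Scheduler(SchedulerMetricsMixin, SchedulerProfilerMixin):
    def __init__(self):
        # Mixin methods resolve through the MRO, so __init__ can call them
        # exactly as it did when they were defined directly on Scheduler.
        self.init_metrics(tp_rank=0, pp_rank=0, dp_rank=None)
        self.init_profiler()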
@@ -237,15 +200,18 @@ class Scheduler(
  port_args: PortArgs,
  gpu_id: int,
  tp_rank: int,
+ moe_ep_rank: int,
  pp_rank: int,
  dp_rank: Optional[int],
  ):
  # Parse args
  self.server_args = server_args
  self.tp_rank = tp_rank
+ self.moe_ep_rank = moe_ep_rank
  self.pp_rank = pp_rank
  self.dp_rank = dp_rank
  self.tp_size = server_args.tp_size
+ self.moe_ep_size = server_args.ep_size
  self.pp_size = server_args.pp_size
  self.dp_size = server_args.dp_size
  self.schedule_policy = server_args.schedule_policy
@@ -266,7 +232,7 @@ class Scheduler(
  self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
  self.enable_hicache_storage = server_args.hicache_storage_backend is not None
  self.page_size = server_args.page_size
- self.dp_size = server_args.dp_size
+
  self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
  compute_dp_attention_world_info(
  server_args.enable_dp_attention,
@@ -284,10 +250,13 @@ class Scheduler(
  self.recv_from_tokenizer = get_zmq_socket(
  context, zmq.PULL, port_args.scheduler_input_ipc_name, False
  )
+ self.recv_from_rpc = get_zmq_socket(
+ context, zmq.DEALER, port_args.rpc_ipc_name, False
+ )
+
  self.send_to_tokenizer = get_zmq_socket(
  context, zmq.PUSH, port_args.tokenizer_ipc_name, False
  )
-
  if server_args.skip_tokenizer_init:
  # Directly send to the TokenizerManager
  self.send_to_detokenizer = get_zmq_socket(
@@ -299,9 +268,6 @@ class Scheduler(
  context, zmq.PUSH, port_args.detokenizer_ipc_name, False
  )

- self.recv_from_rpc = get_zmq_socket(
- context, zmq.DEALER, port_args.rpc_ipc_name, False
- )
  if self.server_args.sleep_on_idle:
  self.idle_sleeper = IdleSleeper(
  [
@@ -347,6 +313,7 @@ class Scheduler(
  server_args=server_args,
  gpu_id=gpu_id,
  tp_rank=tp_rank,
+ moe_ep_rank=moe_ep_rank,
  pp_rank=pp_rank,
  dp_rank=dp_rank,
  nccl_port=port_args.nccl_port,
@@ -359,6 +326,7 @@ class Scheduler(
  self.draft_worker = EAGLEWorker(
  gpu_id=gpu_id,
  tp_rank=tp_rank,
+ moe_ep_rank=moe_ep_rank,
  server_args=server_args,
  nccl_port=port_args.nccl_port,
  target_worker=self.tp_worker,
@@ -398,7 +366,7 @@ class Scheduler(
  global_server_args_dict.update(worker_global_server_args_dict)
  set_random_seed(self.random_seed)

- # Hybrid
+ # Hybrid memory pool
  self.is_hybrid = self.tp_worker.is_hybrid
  if self.is_hybrid:
  self.sliding_window_size = self.tp_worker.sliding_window_size
@@ -515,6 +483,15 @@ class Scheduler(
  self.init_metrics(tp_rank, pp_rank, dp_rank)
  self.init_kv_events(server_args.kv_events_config)

+ # Init disaggregation
+ self.disaggregation_mode = DisaggregationMode(
+ self.server_args.disaggregation_mode
+ )
+ self.init_disaggregation()
+
+ if get_bool_env_var("SGLANG_GC_LOG"):
+ configure_gc_logger()
+
  # Init request dispatcher
  self._request_dispatcher = TypeBasedDispatcher(
  [
@@ -545,22 +522,6 @@ class Scheduler(
  ]
  )

- # Init disaggregation
- self.disaggregation_mode = DisaggregationMode(
- self.server_args.disaggregation_mode
- )
- self.init_disaggregation()
-
- if get_bool_env_var("SGLANG_GC_LOG"):
- configure_gc_logger()
-
- def current_scheduler_metrics_enabled(self):
- return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
-
- def maybe_sleep_on_idle(self):
- if self.idle_sleeper is not None:
- self.idle_sleeper.maybe_sleep()
-
  def init_tokenizer(self):
  server_args = self.server_args

@@ -627,6 +588,7 @@ class Scheduler(
  == "fa3" # hot fix for incompatibility
  else server_args.hicache_io_backend
  ),
+ hicache_mem_layout=server_args.hicache_mem_layout,
  hicache_storage_backend=server_args.hicache_storage_backend,
  )
  self.tp_worker.register_hicache_layer_transfer_counter(
@@ -668,50 +630,6 @@ class Scheduler(
  embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
  init_embedding_cache(embedding_cache_size * 1024 * 1024)

- def init_profier(self):
- self.torch_profiler = None
- self.torch_profiler_output_dir: Optional[str] = None
- self.profiler_activities: Optional[List[str]] = None
- self.profile_id: Optional[str] = None
- self.profiler_start_forward_ct: Optional[int] = None
- self.profiler_target_forward_ct: Optional[int] = None
- self.profiler_target_prefill_ct: Optional[int] = None
- self.profiler_target_decode_ct: Optional[int] = None
- self.profiler_prefill_ct: Optional[int] = None
- self.profiler_decode_ct: Optional[int] = None
- self.profile_by_stage: bool = False
- self.profile_steps: Optional[int] = None
- self.profile_in_progress: bool = False
- self.rpd_profiler = None
-
- def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
- self.last_gen_throughput: float = 0.0
- self.last_input_throughput: float = 0.0
- self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
- self.spec_num_total_accepted_tokens = 0
- self.spec_num_total_forward_ct = 0
- self.cum_spec_accept_length = 0
- self.cum_spec_accept_count = 0
- self.total_retracted_reqs = 0
- self.stats = SchedulerStats()
- if self.enable_metrics:
- engine_type = "unified"
- labels = {
- "model_name": self.server_args.served_model_name,
- "engine_type": engine_type,
- "tp_rank": tp_rank,
- "pp_rank": pp_rank,
- }
- if dp_rank is not None:
- labels["dp_rank"] = dp_rank
- self.metrics_collector = SchedulerMetricsCollector(labels=labels)
-
- def init_kv_events(self, kv_events_config: Optional[str]):
- if self.enable_kv_cache_events:
- self.kv_event_publisher = EventPublisherFactory.create(
- kv_events_config, self.attn_dp_rank
- )
-
  def init_disaggregation(self):
  self.transfer_backend = TransferBackend(
  self.server_args.disaggregation_transfer_backend
@@ -820,10 +738,7 @@ class Scheduler(
  self.process_batch_result(batch, result)
  else:
  # When the server is idle, do self-check and re-init some states
- self.check_memory()
- self.check_tree_cache()
- self.new_token_ratio = self.init_new_token_ratio
- self.maybe_sleep_on_idle()
+ self.self_check_during_idle()

  self.last_batch = batch

@@ -866,10 +781,7 @@ class Scheduler(
  )
  elif batch is None:
  # When the server is idle, do self-check and re-init some states
- self.check_memory()
- self.check_tree_cache()
- self.new_token_ratio = self.init_new_token_ratio
- self.maybe_sleep_on_idle()
+ self.self_check_during_idle()

  self.last_batch = batch

@@ -1003,10 +915,8 @@ class Scheduler(

  # When the server is idle, self-check and re-init some states
  if server_is_idle:
- self.check_memory()
- self.check_tree_cache()
- self.new_token_ratio = self.init_new_token_ratio
- self.maybe_sleep_on_idle()
+ # When the server is idle, do self-check and re-init some states
+ self.self_check_during_idle()

  def recv_requests(self) -> List[Req]:
  """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
@@ -1281,23 +1191,28 @@ class Scheduler(
  def _add_request_to_queue(self, req: Req):
  req.queue_time_start = time.perf_counter()
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
+ self._prefetch_kvcache(req)
  self.disagg_prefill_bootstrap_queue.add(
  req, self.model_config.num_key_value_heads
  )
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
  self.disagg_decode_prealloc_queue.add(req)
  else:
- if self.enable_hicache_storage:
- req.init_next_round_input(self.tree_cache)
- last_hash = req.last_host_node.get_last_hash_value()
- matched_len = len(req.prefix_indices) + req.host_hit_length
- if (matched_len > 0 and last_hash is not None) or matched_len == 0:
- new_input_tokens = req.fill_ids[matched_len:]
- self.tree_cache.prefetch_from_storage(
- req.rid, req.last_host_node, new_input_tokens, last_hash
- )
+ self._prefetch_kvcache(req)
  self.waiting_queue.append(req)

+ def _prefetch_kvcache(self, req: Req):
+ if self.enable_hicache_storage:
+ req.init_next_round_input(self.tree_cache)
+ last_hash = req.last_host_node.get_last_hash_value()
+ matched_len = len(req.prefix_indices) + req.host_hit_length
+ # todo, free-form fetching, calculating hash keys on the fly
+ if (matched_len > 0 and last_hash is not None) or matched_len == 0:
+ new_input_tokens = req.fill_ids[matched_len:]
+ self.tree_cache.prefetch_from_storage(
+ req.rid, req.last_host_node, new_input_tokens, last_hash
+ )
+
  def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
  self.disagg_prefill_bootstrap_queue.extend(
@@ -1355,170 +1270,11 @@ class Scheduler(
  req.logprob_start_len = len(req.origin_input_ids) - 1
  self._add_request_to_queue(req)

- def _emit_kv_metrics(self):
- kv_metrics = KvMetrics()
- kv_metrics.request_active_slots = self.stats.num_running_reqs
- kv_metrics.request_total_slots = self.max_running_requests
- kv_metrics.kv_active_blocks = int(
- self.stats.token_usage * self.max_total_num_tokens
- )
- kv_metrics.kv_total_blocks = self.max_total_num_tokens
- kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
- kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
- kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
- kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
-
- if not self.send_metrics_from_scheduler.closed:
- self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
-
- def log_prefill_stats(
- self,
- adder: PrefillAdder,
- can_run_list: List[Req],
- running_bs: int,
- ):
- gap_latency = time.perf_counter() - self.last_prefill_stats_tic
- self.last_prefill_stats_tic = time.perf_counter()
- self.last_input_throughput = self.last_prefill_tokens / gap_latency
- self.last_prefill_tokens = adder.log_input_tokens
-
- if self.is_hybrid:
- (
- full_num_used,
- swa_num_used,
- full_token_usage,
- swa_token_usage,
- _,
- _,
- _,
- _,
- ) = self._get_swa_token_info()
- num_used = max(full_num_used, swa_num_used)
- token_usage = max(full_token_usage, swa_token_usage)
- token_msg = (
- f"full token usage: {full_token_usage:.2f}, "
- f"swa token usage: {swa_token_usage:.2f}, "
- )
- else:
- num_used, token_usage, _, _ = self._get_token_info()
- token_msg = f"token usage: {token_usage:.2f}, "
-
- num_new_seq = len(can_run_list)
- f = (
- f"Prefill batch. "
- f"#new-seq: {num_new_seq}, "
- f"#new-token: {adder.log_input_tokens}, "
- f"#cached-token: {adder.log_hit_tokens}, "
- f"{token_msg}"
- )
-
- if self.disaggregation_mode == DisaggregationMode.PREFILL:
- f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
- f += f"#queue-req: {len(self.waiting_queue)}, "
- f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
- f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
- else:
- f += f"#running-req: {running_bs}, "
- f += f"#queue-req: {len(self.waiting_queue)}, "
-
- logger.info(f)
-
- if self.enable_metrics:
- total_tokens = adder.log_input_tokens + adder.log_hit_tokens
-
- cache_hit_rate = (
- adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
- )
- self.stats.num_running_reqs = running_bs
- self.stats.num_used_tokens = num_used
- self.stats.token_usage = round(token_usage, 2)
- self.stats.num_queue_reqs = len(self.waiting_queue)
- self.stats.cache_hit_rate = cache_hit_rate
-
- total_queue_latency = 0
- for req in can_run_list:
- total_queue_latency += req.queue_time_end - req.queue_time_start
- self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
-
- self.metrics_collector.log_stats(self.stats)
- self._emit_kv_metrics()
- self._publish_kv_events()
-
- def log_decode_stats(
- self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
- ):
- batch = running_batch or self.running_batch
-
- gap_latency = time.perf_counter() - self.last_decode_stats_tic
- self.last_decode_stats_tic = time.perf_counter()
- self.last_gen_throughput = self.num_generated_tokens / gap_latency
- self.num_generated_tokens = 0
- num_running_reqs = len(batch.reqs)
- if self.is_hybrid:
- (
- full_num_used,
- swa_num_used,
- full_token_usage,
- swa_token_usage,
- _,
- _,
- _,
- _,
- ) = self._get_swa_token_info()
- num_used = max(full_num_used, swa_num_used)
- token_usage = max(full_token_usage, swa_token_usage)
- token_msg = (
- f"#full token: {full_num_used}, "
- f"full token usage: {full_token_usage:.2f}, "
- f"#swa token: {swa_num_used}, "
- f"swa token usage: {swa_token_usage:.2f}, "
- )
- else:
- num_used, token_usage, _, _ = self._get_token_info()
- token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
-
- if RECORD_STEP_TIME:
- self.step_time_dict[num_running_reqs].append(
- gap_latency / self.server_args.decode_log_interval
- )
-
- msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
-
- if self.spec_algorithm.is_none():
- spec_accept_length = 0
- else:
- spec_accept_length = (
- self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
- )
- self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
- self.cum_spec_accept_count += self.spec_num_total_forward_ct
- self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
- msg += f"accept len: {spec_accept_length:.2f}, "
-
- if self.disaggregation_mode == DisaggregationMode.DECODE:
- msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
- msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
-
- msg += (
- f"cuda graph: {can_run_cuda_graph}, "
- f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
- f"#queue-req: {len(self.waiting_queue)}, "
- )
-
- logger.info(msg)
- if self.enable_metrics:
- self.stats.num_running_reqs = num_running_reqs
- self.stats.num_used_tokens = num_used
- self.stats.token_usage = round(token_usage, 2)
- self.stats.cache_hit_rate = 0.0
- self.stats.gen_throughput = self.last_gen_throughput
- self.stats.num_queue_reqs = len(self.waiting_queue)
- self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
- self.stats.spec_accept_length = spec_accept_length
- self.stats.total_retracted_reqs = self.total_retracted_reqs
- self.metrics_collector.log_stats(self.stats)
- self._emit_kv_metrics()
- self._publish_kv_events()
+ def self_check_during_idle(self):
+ self.check_memory()
+ self.check_tree_cache()
+ self.new_token_ratio = self.init_new_token_ratio
+ self.maybe_sleep_on_idle()

  def check_memory(self):
  if self.is_hybrid:
@@ -2422,22 +2178,6 @@ class Scheduler(
  barrier()
  return RpcReqOutput(success, "" if not exec else str(exec))

- def save_remote_model(self, params):
- url = params["url"]
-
- worker = self.tp_worker.worker
-
- worker.model_runner.save_remote_model(url)
-
- def save_sharded_model(self, params):
- worker = self.tp_worker.worker
-
- worker.model_runner.save_sharded_model(
- path=params["path"],
- pattern=params["pattern"],
- max_size=params["max_size"],
- )
-
  def abort_request(self, recv_req: AbortReq):
  # Delete requests in the waiting queue
  to_del = []
@@ -2515,16 +2255,6 @@ class Scheduler(
  def _pause_engine(self) -> Tuple[List[Req], int]:
  raise NotImplementedError()

- def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
- """In-place update of the weights from disk."""
- success, message = self.tp_worker.update_weights_from_disk(recv_req)
- if success:
- flush_cache_success = self.flush_cache()
- assert flush_cache_success, "Cache flush failed after updating weights"
- else:
- logger.error(message)
- return UpdateWeightFromDiskReqOutput(success, message, 0)
-
  def load_lora_adapter(
  self, recv_req: LoadLoRAAdapterReqInput
  ) -> LoadLoRAAdapterReqOutput:
@@ -2541,81 +2271,6 @@ class Scheduler(
  result = self.tp_worker.unload_lora_adapter(recv_req)
  return result

- def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
- """Initialize the online model parameter update group."""
- success, message = self.tp_worker.init_weights_update_group(recv_req)
- return InitWeightsUpdateGroupReqOutput(success, message)
-
- def update_weights_from_distributed(
- self,
- recv_req: UpdateWeightsFromDistributedReqInput,
- ) -> Tuple[bool, str]:
- """Update the online model parameter."""
- success, message = self.tp_worker.update_weights_from_distributed(recv_req)
- if success:
- if recv_req.flush_cache:
- flush_cache_success = self.flush_cache()
- assert flush_cache_success, "Cache flush failed after updating weights"
- else:
- logger.error(message)
- return UpdateWeightsFromDistributedReqOutput(success, message)
-
- def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
- """Update the online model parameter from tensors."""
- success, message = self.tp_worker.update_weights_from_tensor(recv_req)
- # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
- if success:
- if recv_req.flush_cache:
- flush_cache_success = self.flush_cache()
- assert flush_cache_success, "Cache flush failed after updating weights"
- else:
- logger.error(message)
- barrier(group=self.tp_cpu_group)
- return UpdateWeightsFromTensorReqOutput(success, message)
-
- def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
- parameter = self.tp_worker.get_weights_by_name(recv_req)
- return GetWeightsByNameReqOutput(parameter)
-
- def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
- tags = recv_req.tags
-
- if tags is None or len(tags) == 0:
- tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
-
- if GPU_MEMORY_TYPE_KV_CACHE in tags:
- self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
- self.flush_cache()
-
- if GPU_MEMORY_TYPE_WEIGHTS in tags:
- self.stashed_model_static_state = _export_static_state(
- self.tp_worker.worker.model_runner.model
- )
- torch.distributed.barrier(self.tp_cpu_group)
- self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
-
- return ReleaseMemoryOccupationReqOutput()
-
- def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
- tags = recv_req.tags
-
- if tags is None or len(tags) == 0:
- tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
-
- if GPU_MEMORY_TYPE_WEIGHTS in tags:
- self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
- torch.distributed.barrier(self.tp_cpu_group)
- _import_static_state(
- self.tp_worker.worker.model_runner.model,
- self.stashed_model_static_state,
- )
- del self.stashed_model_static_state
-
- if GPU_MEMORY_TYPE_KV_CACHE in tags:
- self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
-
- return ResumeMemoryOccupationReqOutput()
-
  def slow_down(self, recv_req: SlowDownReqInput):
  t = recv_req.forward_sleep_time
  if t is not None and t <= 0:
@@ -2623,254 +2278,6 @@ class Scheduler(
  self.forward_sleep_time = t
  return SlowDownReqOutput()

- def profile(self, recv_req: ProfileReq):
- if recv_req.type == ProfileReqType.START_PROFILE:
- if recv_req.profile_by_stage or recv_req.start_step:
- return self.init_profile(
- recv_req.output_dir,
- recv_req.start_step,
- recv_req.num_steps,
- recv_req.activities,
- recv_req.with_stack,
- recv_req.record_shapes,
- recv_req.profile_by_stage,
- recv_req.profile_id,
- )
- else:
- self.init_profile(
- recv_req.output_dir,
- recv_req.start_step,
- recv_req.num_steps,
- recv_req.activities,
- recv_req.with_stack,
- recv_req.record_shapes,
- recv_req.profile_by_stage,
- recv_req.profile_id,
- )
- return self.start_profile(True)
- else:
- return self.stop_profile()
-
- def init_profile(
- self,
- output_dir: Optional[str],
- start_step: Optional[int],
- num_steps: Optional[int],
- activities: Optional[List[str]],
- with_stack: Optional[bool],
- record_shapes: Optional[bool],
- profile_by_stage: bool,
- profile_id: str,
- ) -> ProfileReqOutput:
- if self.profile_in_progress:
- return ProfileReqOutput(
- success=False,
- message="Profiling is already in progress. Call /stop_profile first.",
- )
-
- self.profile_by_stage = profile_by_stage
-
- if output_dir is None:
- output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
- if activities is None:
- activities = ["CPU", "GPU"]
-
- self.torch_profiler_output_dir = output_dir
- self.torch_profiler_with_stack = with_stack
- self.torch_profiler_record_shapes = record_shapes
- self.profiler_activities = activities
- self.profile_id = profile_id
-
- if start_step:
- self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
-
- if num_steps:
- self.profile_steps = num_steps
- if self.profile_by_stage:
- self.profiler_target_prefill_ct = num_steps
- self.profiler_target_decode_ct = num_steps
- self.profiler_prefill_ct = 0
- self.profiler_decode_ct = 0
- elif start_step:
- self.profiler_target_forward_ct = (
- self.profiler_start_forward_ct + num_steps
- )
- else:
- self.profiler_target_forward_ct = self.forward_ct + num_steps
- # The caller will be notified when reaching profiler_target_forward_ct
- else:
- self.profiler_target_forward_ct = None
-
- return ProfileReqOutput(success=True, message="Succeeded")
-
- def start_profile(
- self, stage: Optional[ForwardMode] = None
- ) -> ProfileReqOutput | None:
- stage_str = f" for {stage.__str__()}" if stage else ""
- logger.info(
- f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
- )
-
- activities = self.profiler_activities
- with_stack = self.torch_profiler_with_stack
- record_shapes = self.torch_profiler_record_shapes
-
- activity_map = {
- "CPU": torch.profiler.ProfilerActivity.CPU,
- "GPU": torch.profiler.ProfilerActivity.CUDA,
- }
- torchprof_activities = [
- activity_map[a] for a in activities if a in activity_map
- ]
-
- if "RPD" in activities:
- from rpdTracerControl import rpdTracerControl
-
- rpdTracerControl.skipCreate()
-
- self.rpd_profile_path = os.path.join(
- self.torch_profiler_output_dir,
- "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
- )
-
- if self.tp_rank == 0:
- import sqlite3
-
- from rocpd.schema import RocpdSchema
-
- if os.path.exists("trace.rpd"):
- os.unlink("trace.rpd")
- schema = RocpdSchema()
- connection = sqlite3.connect("trace.rpd")
- schema.writeSchema(connection)
- connection.commit()
- del connection
- torch.distributed.barrier(self.tp_cpu_group)
-
- self.rpd_profiler = rpdTracerControl()
- self.rpd_profiler.setPythonTrace(True)
- self.rpd_profiler.start()
- self.rpd_profiler.rangePush("", "rpd profile range", "")
- self.profile_in_progress = True
- elif torchprof_activities:
- self.torch_profiler = torch.profiler.profile(
- activities=torchprof_activities,
- with_stack=with_stack if with_stack is not None else True,
- record_shapes=record_shapes if record_shapes is not None else False,
- )
- self.torch_profiler.start()
- self.profile_in_progress = True
-
- if "MEM" in activities:
- torch.cuda.memory._record_memory_history(max_entries=100000)
- self.profile_in_progress = True
-
- if "CUDA_PROFILER" in activities:
- torch.cuda.cudart().cudaProfilerStart()
- self.profile_in_progress = True
-
- return ProfileReqOutput(success=True, message="Succeeded")
-
- def stop_profile(
- self, stage: Optional[ForwardMode] = None
- ) -> ProfileReqOutput | None:
- if not self.profile_in_progress:
- return ProfileReqOutput(
- success=False,
- message="Profiling is not in progress. Call /start_profile first.",
- )
-
- if not Path(self.torch_profiler_output_dir).exists():
- Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
-
- stage_suffix = f"-{stage.__str__()}" if stage else ""
- logger.info("Stop profiling" + stage_suffix + "...")
- if self.torch_profiler is not None:
- self.torch_profiler.stop()
- self.torch_profiler.export_chrome_trace(
- os.path.join(
- self.torch_profiler_output_dir,
- self.profile_id
- + f"-TP-{self.tp_rank}"
- + stage_suffix
- + ".trace.json.gz",
- )
- )
- torch.distributed.barrier(self.tp_cpu_group)
-
- if self.rpd_profiler is not None:
- self.rpd_profiler.rangePop()
- self.rpd_profiler.stop()
- self.rpd_profiler.flush()
-
- torch.distributed.barrier(self.tp_cpu_group)
- if self.tp_rank == 0:
- from sglang.srt.utils import rpd_to_chrome_trace
-
- rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
- self.rpd_profiler = None
- self.rpd_profiler_path = None
-
- if self.profiler_activities is not None and "MEM" in self.profiler_activities:
- memory_profile_path = os.path.join(
- self.torch_profiler_output_dir,
- str(time.time())
- + f"-TP-{self.tp_rank}-memory"
- + stage_suffix
- + ".pickle",
- )
- torch.cuda.memory._dump_snapshot(memory_profile_path)
- torch.cuda.memory._record_memory_history(enabled=None)
-
- if "CUDA_PROFILER" in self.profiler_activities:
- torch.cuda.cudart().cudaProfilerStop()
-
- logger.info(
- "Profiling done. Traces are saved to: %s",
- self.torch_profiler_output_dir,
- )
- self.torch_profiler = None
- self.profile_in_progress = False
- self.profiler_start_forward_ct = None
-
- return ProfileReqOutput(success=True, message="Succeeded.")
-
- def _profile_batch_predicate(self, batch):
- if self.profile_by_stage:
- if batch.forward_mode.is_prefill():
- if self.profiler_prefill_ct == 0:
- self.start_profile(batch.forward_mode)
- self.profiler_prefill_ct += 1
- if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
- if self.profile_in_progress:
- self.stop_profile(stage=ForwardMode.EXTEND)
- elif batch.forward_mode.is_decode():
- if self.profiler_decode_ct == 0:
- if self.profile_in_progress:
- # force trace flush
- self.stop_profile(ForwardMode.EXTEND)
- self.start_profile(batch.forward_mode)
- self.profiler_decode_ct += 1
- if self.profiler_decode_ct > self.profiler_target_decode_ct:
- if self.profile_in_progress:
- self.stop_profile(stage=ForwardMode.DECODE)
- elif batch.forward_mode.is_idle():
- pass
- else:
- raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
- else:
- # Check profiler
- if (
- self.profiler_target_forward_ct
- and self.profiler_target_forward_ct <= self.forward_ct
- ):
- self.stop_profile()
- if (
- self.profiler_start_forward_ct
- and self.profiler_start_forward_ct == self.forward_ct
- ):
- self.start_profile()
-
  def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
  if recv_req == ExpertDistributionReq.START_RECORD:
  get_global_expert_distribution_recorder().start_record()
@@ -2879,7 +2286,7 @@ class Scheduler(
  elif recv_req == ExpertDistributionReq.DUMP_RECORD:
  get_global_expert_distribution_recorder().dump_record()
  else:
- raise ValueError("Unrecognized ExpertDistributionReq value")
+ raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
  return ExpertDistributionReqOutput()

  def open_session(self, recv_req: OpenSessionReqInput):
@@ -2915,34 +2322,41 @@ class Scheduler(
  prefix += f" PP{self.pp_rank}"
  return prefix

- def _publish_kv_events(self):
- if self.enable_kv_cache_events:
- events = self.tree_cache.take_events()
- if events:
- batch = KVEventBatch(ts=time.time(), events=events)
- self.kv_event_publisher.publish(batch)
+ def current_scheduler_metrics_enabled(self):
+ return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers

+ def maybe_sleep_on_idle(self):
+ if self.idle_sleeper is not None:
+ self.idle_sleeper.maybe_sleep()

- def is_health_check_generate_req(recv_req):
- return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")

+ class IdleSleeper:
+ """
+ In setups which have long inactivity periods it is desirable to reduce
+ system power consumption when sglang does nothing. This would lead not only
+ to power savings, but also to more CPU thermal headroom when a request
+ eventually comes. This is important in cases when multiple GPUs are connected
+ as each GPU would otherwise pin one thread at 100% CPU usage.

- def is_work_request(recv_req):
- return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
+ The simplest solution is to use zmq.Poller on all sockets that may receive
+ data that needs handling immediately.
+ """

+ def __init__(self, sockets):
+ self.poller = zmq.Poller()
+ for s in sockets:
+ self.poller.register(s, zmq.POLLIN)

- def _export_static_state(model):
- return dict(
- buffers=[
- (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
- ]
- )
+ def maybe_sleep(self):
+ self.poller.poll(1000)


- def _import_static_state(model, static_params):
- self_named_buffers = dict(model.named_buffers())
- for name, tensor in static_params["buffers"]:
- self_named_buffers[name][...] = tensor
+ def is_health_check_generate_req(recv_req):
+ return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
+
+
+ def is_work_request(recv_req):
+ return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))


  def run_scheduler_process(
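The IdleSleeper class, relocated toward the end of the module in the hunk above, implements the idle-sleep approach its docstring describes: every socket that can deliver work is registered with a zmq.Poller, and the scheduler blocks on the poller briefly instead of spinning. A minimal usage sketch, assuming a hypothetical IPC endpoint; only IdleSleeper itself comes from the diff:

import zmq

# Hypothetical socket setup; the real scheduler registers its tokenizer and RPC sockets.
context = zmq.Context()
sock = context.socket(zmq.PULL)
sock.bind("ipc:///tmp/example_scheduler_input")

sleeper = IdleSleeper([sock])  # IdleSleeper as defined in the hunk above

while True:
    # poll() blocks for up to 1000 ms when no registered socket has pending
    # data, so an idle scheduler process does not pin a CPU core at 100%.
    sleeper.maybe_sleep()
    # ...receive from sock and handle any pending requests here...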
@@ -2950,6 +2364,7 @@ def run_scheduler_process(
  port_args: PortArgs,
  gpu_id: int,
  tp_rank: int,
+ moe_ep_rank: int,
  pp_rank: int,
  dp_rank: Optional[int],
  pipe_writer,
@@ -2960,6 +2375,8 @@ def run_scheduler_process(
  prefix += f" DP{dp_rank}"
  if server_args.tp_size > 1:
  prefix += f" TP{tp_rank}"
+ if server_args.ep_size > 1:
+ prefix += f" EP{moe_ep_rank}"
  if server_args.pp_size > 1:
  prefix += f" PP{pp_rank}"

@@ -2983,7 +2400,9 @@ def run_scheduler_process(

  # Create a scheduler and run the event loop
  try:
- scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
+ scheduler = Scheduler(
+ server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank
+ )
  pipe_writer.send(
  {
  "status": "ready",