sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
Files changed (84)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +8 -0
  3. sglang/srt/configs/model_config.py +6 -0
  4. sglang/srt/configs/step3_vl.py +172 -0
  5. sglang/srt/conversation.py +23 -0
  6. sglang/srt/disaggregation/decode.py +2 -8
  7. sglang/srt/disaggregation/prefill.py +2 -6
  8. sglang/srt/distributed/parallel_state.py +86 -1
  9. sglang/srt/entrypoints/engine.py +14 -18
  10. sglang/srt/entrypoints/http_server.py +23 -3
  11. sglang/srt/entrypoints/openai/protocol.py +3 -1
  12. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  13. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  14. sglang/srt/eplb/expert_distribution.py +5 -0
  15. sglang/srt/eplb/expert_location.py +17 -6
  16. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  17. sglang/srt/eplb/expert_location_updater.py +2 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/step3_detector.py +436 -0
  20. sglang/srt/hf_transformers_utils.py +2 -0
  21. sglang/srt/jinja_template_utils.py +4 -1
  22. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  23. sglang/srt/layers/moe/ep_moe/layer.py +98 -603
  24. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
  29. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  30. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  31. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  32. sglang/srt/layers/moe/topk.py +6 -2
  33. sglang/srt/layers/quantization/fp8.py +0 -18
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -0
  35. sglang/srt/layers/quantization/unquant.py +0 -8
  36. sglang/srt/layers/quantization/w4afp8.py +1 -0
  37. sglang/srt/managers/cache_controller.py +143 -45
  38. sglang/srt/managers/data_parallel_controller.py +6 -0
  39. sglang/srt/managers/io_struct.py +12 -2
  40. sglang/srt/managers/scheduler.py +116 -669
  41. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  42. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  43. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  44. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  45. sglang/srt/managers/template_manager.py +62 -19
  46. sglang/srt/managers/tokenizer_manager.py +166 -83
  47. sglang/srt/managers/tp_worker.py +9 -0
  48. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  49. sglang/srt/mem_cache/hicache_storage.py +45 -11
  50. sglang/srt/mem_cache/hiradix_cache.py +15 -4
  51. sglang/srt/mem_cache/memory_pool_host.py +73 -1
  52. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  53. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  54. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
  55. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  56. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  57. sglang/srt/model_executor/model_runner.py +20 -13
  58. sglang/srt/models/arcee.py +532 -0
  59. sglang/srt/models/deepseek_v2.py +15 -56
  60. sglang/srt/models/glm4_moe.py +3 -1
  61. sglang/srt/models/granitemoe.py +3 -0
  62. sglang/srt/models/grok.py +3 -0
  63. sglang/srt/models/hunyuan.py +1 -0
  64. sglang/srt/models/llama4.py +3 -0
  65. sglang/srt/models/mixtral.py +3 -0
  66. sglang/srt/models/olmoe.py +3 -0
  67. sglang/srt/models/phimoe.py +1 -0
  68. sglang/srt/models/qwen3_moe.py +12 -69
  69. sglang/srt/models/step3_vl.py +994 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/poll_based_barrier.py +31 -0
  73. sglang/srt/reasoning_parser.py +2 -1
  74. sglang/srt/server_args.py +18 -13
  75. sglang/srt/speculative/eagle_worker.py +2 -0
  76. sglang/srt/two_batch_overlap.py +8 -3
  77. sglang/test/test_utils.py +53 -0
  78. sglang/utils.py +0 -11
  79. sglang/version.py +1 -1
  80. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
  81. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
  82. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
  83. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
@@ -13,7 +13,6 @@
  # ==============================================================================
  """A scheduler that manages a tensor parallel GPU worker."""

- import datetime
  import faulthandler
  import logging
  import os
@@ -21,10 +20,10 @@ import signal
  import sys
  import threading
  import time
- from collections import defaultdict, deque
+ from collections import deque
  from concurrent import futures
  from dataclasses import dataclass
- from pathlib import Path
+ from http import HTTPStatus
  from types import SimpleNamespace
  from typing import Dict, List, Optional, Tuple, Union

@@ -36,7 +35,6 @@ from torch.distributed import barrier

  from sglang.global_config import global_config
  from sglang.srt.configs.model_config import ModelConfig
- from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
  from sglang.srt.constrained.base_grammar_backend import (
  INVALID_GRAMMAR_OBJ,
  create_grammar_backend,
@@ -46,7 +44,6 @@ from sglang.srt.disaggregation.decode import (
  DecodeTransferQueue,
  SchedulerDisaggregationDecodeMixin,
  )
- from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
  from sglang.srt.disaggregation.prefill import (
  PrefillBootstrapQueue,
  SchedulerDisaggregationPrefillMixin,
@@ -77,21 +74,15 @@ from sglang.srt.managers.io_struct import (
  GetInternalStateReq,
  GetInternalStateReqOutput,
  GetWeightsByNameReqInput,
- GetWeightsByNameReqOutput,
  HealthCheckOutput,
  InitWeightsUpdateGroupReqInput,
- InitWeightsUpdateGroupReqOutput,
  LoadLoRAAdapterReqInput,
  LoadLoRAAdapterReqOutput,
  OpenSessionReqInput,
  OpenSessionReqOutput,
  ProfileReq,
- ProfileReqOutput,
- ProfileReqType,
  ReleaseMemoryOccupationReqInput,
- ReleaseMemoryOccupationReqOutput,
  ResumeMemoryOccupationReqInput,
- ResumeMemoryOccupationReqOutput,
  RpcReqInput,
  RpcReqOutput,
  SetInternalStateReq,
@@ -103,11 +94,8 @@ from sglang.srt.managers.io_struct import (
  UnloadLoRAAdapterReqInput,
  UnloadLoRAAdapterReqOutput,
  UpdateWeightFromDiskReqInput,
- UpdateWeightFromDiskReqOutput,
  UpdateWeightsFromDistributedReqInput,
- UpdateWeightsFromDistributedReqOutput,
  UpdateWeightsFromTensorReqInput,
- UpdateWeightsFromTensorReqOutput,
  )
  from sglang.srt.managers.mm_utils import init_embedding_cache
  from sglang.srt.managers.schedule_batch import (
@@ -122,9 +110,18 @@ from sglang.srt.managers.schedule_policy import (
  PrefillAdder,
  SchedulePolicy,
  )
+ from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
+ from sglang.srt.managers.scheduler_metrics_mixin import (
+ RECORD_STEP_TIME,
+ SchedulerMetricsMixin,
+ )
  from sglang.srt.managers.scheduler_output_processor_mixin import (
  SchedulerOutputProcessorMixin,
  )
+ from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
+ from sglang.srt.managers.scheduler_update_weights_mixin import (
+ SchedulerUpdateWeightsMixin,
+ )
  from sglang.srt.managers.session_controller import Session
  from sglang.srt.managers.tp_worker import TpModelWorker
  from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
@@ -133,7 +130,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
  from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
- from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
  from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
  from sglang.srt.reasoning_parser import ReasoningParser
  from sglang.srt.server_args import PortArgs, ServerArgs
@@ -166,7 +162,6 @@ logger = logging.getLogger(__name__)

  # Test retract decode for debugging purposes
  TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
- RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
  GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

  _is_cpu = is_cpu()
@@ -189,41 +184,11 @@ class EmbeddingBatchResult:
  bid: int


- class KvMetrics:
- def __init__(self):
- self.request_active_slots = None
- self.request_total_slots = None
- self.kv_active_blocks = None
- self.kv_total_blocks = None
- self.num_requests_waiting = None
- self.gpu_cache_usage_perc = None
- self.gpu_prefix_cache_hit_rate = None
- self.data_parallel_rank = None
-
-
- class IdleSleeper:
- """
- In setups which have long inactivity periods it is desirable to reduce
- system power consumption when sglang does nothing. This would lead not only
- to power savings, but also to more CPU thermal headroom when a request
- eventually comes. This is important in cases when multiple GPUs are connected
- as each GPU would otherwise pin one thread at 100% CPU usage.
-
- The simplest solution is to use zmq.Poller on all sockets that may receive
- data that needs handling immediately.
- """
-
- def __init__(self, sockets):
- self.poller = zmq.Poller()
- for s in sockets:
- self.poller.register(s, zmq.POLLIN)
-
- def maybe_sleep(self):
- self.poller.poll(1000)
-
-
  class Scheduler(
  SchedulerOutputProcessorMixin,
+ SchedulerUpdateWeightsMixin,
+ SchedulerProfilerMixin,
+ SchedulerMetricsMixin,
  SchedulerDisaggregationDecodeMixin,
  SchedulerDisaggregationPrefillMixin,
  ):
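The hunk above captures the main structural change in this release: profiling, metrics, and weight-update logic leave scheduler.py and move into the new scheduler_profiler_mixin.py, scheduler_metrics_mixin.py, and scheduler_update_weights_mixin.py files listed at the top, and Scheduler now picks those methods up through multiple inheritance. A minimal sketch of that composition pattern is shown below; the class names here are hypothetical stand-ins for illustration, not the actual sglang mixins.

    # Illustrative sketch only: simplified stand-ins showing how methods defined
    # in separate mixin classes end up on the composed class via Python's MRO.
    class MetricsMixin:
        def init_metrics(self) -> None:
            # In sglang, the real version lives in scheduler_metrics_mixin.py.
            self.stats = {}

    class ProfilerMixin:
        def init_profiler(self) -> None:
            # In sglang, the real version lives in scheduler_profiler_mixin.py.
            self.profile_in_progress = False

    class MiniScheduler(MetricsMixin, ProfilerMixin):
        def __init__(self) -> None:
            self.init_metrics()
            self.init_profiler()

    s = MiniScheduler()
    print(type(s).__mro__)  # MiniScheduler, MetricsMixin, ProfilerMixin, object

The upshot for the diff below is that large blocks removed from scheduler.py are not deleted functionality; they reappear as methods of these mixins.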
@@ -235,15 +200,18 @@ class Scheduler(
  port_args: PortArgs,
  gpu_id: int,
  tp_rank: int,
+ moe_ep_rank: int,
  pp_rank: int,
  dp_rank: Optional[int],
  ):
  # Parse args
  self.server_args = server_args
  self.tp_rank = tp_rank
+ self.moe_ep_rank = moe_ep_rank
  self.pp_rank = pp_rank
  self.dp_rank = dp_rank
  self.tp_size = server_args.tp_size
+ self.moe_ep_size = server_args.ep_size
  self.pp_size = server_args.pp_size
  self.dp_size = server_args.dp_size
  self.schedule_policy = server_args.schedule_policy
@@ -264,7 +232,7 @@
  self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
  self.enable_hicache_storage = server_args.hicache_storage_backend is not None
  self.page_size = server_args.page_size
- self.dp_size = server_args.dp_size
+
  self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
  compute_dp_attention_world_info(
  server_args.enable_dp_attention,
@@ -282,10 +250,13 @@
  self.recv_from_tokenizer = get_zmq_socket(
  context, zmq.PULL, port_args.scheduler_input_ipc_name, False
  )
+ self.recv_from_rpc = get_zmq_socket(
+ context, zmq.DEALER, port_args.rpc_ipc_name, False
+ )
+
  self.send_to_tokenizer = get_zmq_socket(
  context, zmq.PUSH, port_args.tokenizer_ipc_name, False
  )
-
  if server_args.skip_tokenizer_init:
  # Directly send to the TokenizerManager
  self.send_to_detokenizer = get_zmq_socket(
@@ -297,9 +268,6 @@
  context, zmq.PUSH, port_args.detokenizer_ipc_name, False
  )

- self.recv_from_rpc = get_zmq_socket(
- context, zmq.DEALER, port_args.rpc_ipc_name, False
- )
  if self.server_args.sleep_on_idle:
  self.idle_sleeper = IdleSleeper(
  [
@@ -345,6 +313,7 @@
  server_args=server_args,
  gpu_id=gpu_id,
  tp_rank=tp_rank,
+ moe_ep_rank=moe_ep_rank,
  pp_rank=pp_rank,
  dp_rank=dp_rank,
  nccl_port=port_args.nccl_port,
@@ -357,6 +326,7 @@
  self.draft_worker = EAGLEWorker(
  gpu_id=gpu_id,
  tp_rank=tp_rank,
+ moe_ep_rank=moe_ep_rank,
  server_args=server_args,
  nccl_port=port_args.nccl_port,
  target_worker=self.tp_worker,
@@ -370,6 +340,7 @@
  self.max_total_num_tokens,
  self.max_prefill_tokens,
  self.max_running_requests,
+ self.max_queued_requests,
  self.max_req_len,
  self.max_req_input_len,
  self.random_seed,
@@ -395,7 +366,7 @@
  global_server_args_dict.update(worker_global_server_args_dict)
  set_random_seed(self.random_seed)

- # Hybrid
+ # Hybrid memory pool
  self.is_hybrid = self.tp_worker.is_hybrid
  if self.is_hybrid:
  self.sliding_window_size = self.tp_worker.sliding_window_size
@@ -502,10 +473,25 @@
  )
  self.init_profier()

+ self.input_blocker = (
+ SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
+ if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
+ else None
+ )
+
  # Init metrics stats
  self.init_metrics(tp_rank, pp_rank, dp_rank)
  self.init_kv_events(server_args.kv_events_config)

+ # Init disaggregation
+ self.disaggregation_mode = DisaggregationMode(
+ self.server_args.disaggregation_mode
+ )
+ self.init_disaggregation()
+
+ if get_bool_env_var("SGLANG_GC_LOG"):
+ configure_gc_logger()
+
  # Init request dispatcher
  self._request_dispatcher = TypeBasedDispatcher(
  [
@@ -536,22 +522,6 @@
  ]
  )

- # Init disaggregation
- self.disaggregation_mode = DisaggregationMode(
- self.server_args.disaggregation_mode
- )
- self.init_disaggregation()
-
- if get_bool_env_var("SGLANG_GC_LOG"):
- configure_gc_logger()
-
- def current_scheduler_metrics_enabled(self):
- return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
-
- def maybe_sleep_on_idle(self):
- if self.idle_sleeper is not None:
- self.idle_sleeper.maybe_sleep()
-
  def init_tokenizer(self):
  server_args = self.server_args

@@ -659,50 +629,6 @@
  embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
  init_embedding_cache(embedding_cache_size * 1024 * 1024)

- def init_profier(self):
- self.torch_profiler = None
- self.torch_profiler_output_dir: Optional[str] = None
- self.profiler_activities: Optional[List[str]] = None
- self.profile_id: Optional[str] = None
- self.profiler_start_forward_ct: Optional[int] = None
- self.profiler_target_forward_ct: Optional[int] = None
- self.profiler_target_prefill_ct: Optional[int] = None
- self.profiler_target_decode_ct: Optional[int] = None
- self.profiler_prefill_ct: Optional[int] = None
- self.profiler_decode_ct: Optional[int] = None
- self.profile_by_stage: bool = False
- self.profile_steps: Optional[int] = None
- self.profile_in_progress: bool = False
- self.rpd_profiler = None
-
- def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
- self.last_gen_throughput: float = 0.0
- self.last_input_throughput: float = 0.0
- self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
- self.spec_num_total_accepted_tokens = 0
- self.spec_num_total_forward_ct = 0
- self.cum_spec_accept_length = 0
- self.cum_spec_accept_count = 0
- self.total_retracted_reqs = 0
- self.stats = SchedulerStats()
- if self.enable_metrics:
- engine_type = "unified"
- labels = {
- "model_name": self.server_args.served_model_name,
- "engine_type": engine_type,
- "tp_rank": tp_rank,
- "pp_rank": pp_rank,
- }
- if dp_rank is not None:
- labels["dp_rank"] = dp_rank
- self.metrics_collector = SchedulerMetricsCollector(labels=labels)
-
- def init_kv_events(self, kv_events_config: Optional[str]):
- if self.enable_kv_cache_events:
- self.kv_event_publisher = EventPublisherFactory.create(
- kv_events_config, self.attn_dp_rank
- )
-
  def init_disaggregation(self):
  self.transfer_backend = TransferBackend(
  self.server_args.disaggregation_transfer_backend
@@ -811,10 +737,7 @@
  self.process_batch_result(batch, result)
  else:
  # When the server is idle, do self-check and re-init some states
- self.check_memory()
- self.check_tree_cache()
- self.new_token_ratio = self.init_new_token_ratio
- self.maybe_sleep_on_idle()
+ self.self_check_during_idle()
  self.last_batch = batch


@@ -857,10 +780,7 @@
  )
  elif batch is None:
  # When the server is idle, do self-check and re-init some states
- self.check_memory()
- self.check_tree_cache()
- self.new_token_ratio = self.init_new_token_ratio
- self.maybe_sleep_on_idle()
+ self.self_check_during_idle()
  self.last_batch = batch


@@ -994,10 +914,8 @@

  # When the server is idle, self-check and re-init some states
  if server_is_idle:
- self.check_memory()
- self.check_tree_cache()
- self.new_token_ratio = self.init_new_token_ratio
- self.maybe_sleep_on_idle()
+ # When the server is idle, do self-check and re-init some states
+ self.self_check_during_idle()

  def recv_requests(self) -> List[Req]:
  """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
@@ -1033,6 +951,9 @@
  else:
  recv_reqs = None

+ if self.input_blocker is not None:
+ recv_reqs = self.input_blocker.handle(recv_reqs)
+
  if self.server_args.enable_dp_attention:
  if self.attn_tp_rank == 0:
  work_reqs = [
@@ -1086,6 +1007,19 @@
  self.return_health_check_ct += 1
  continue

+ # If it is a work request, accept or reject the request based on the request queue size.
+ if is_work_request(recv_req):
+ if len(self.waiting_queue) + 1 > self.max_queued_requests:
+ abort_req = AbortReq(
+ recv_req.rid,
+ finished_reason={
+ "type": "abort",
+ "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+ "message": "The request queue is full.",
+ },
+ )
+ self.send_to_tokenizer.send_pyobj(abort_req)
+ continue
  output = self._request_dispatcher(recv_req)
  if output is not None:
  if isinstance(output, RpcReqOutput):
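The hunk above adds admission control: when a new generate or embedding request would push the waiting queue past max_queued_requests, the scheduler aborts it with HTTPStatus.SERVICE_UNAVAILABLE instead of queueing it. The client-side sketch below is a hedged illustration of how a caller might cope with that; it assumes the abort is surfaced to HTTP clients as a 503 (as the status_code field suggests), that a local server is running on sglang's default port 30000 with the usual /generate endpoint, and that the queue limit is configured through the new server argument touched in server_args.py in this release.

    # Hedged sketch: retry a generation request if the scheduler's queue is full.
    import time
    import requests

    def generate_with_retry(prompt: str, max_retries: int = 5) -> dict:
        payload = {"text": prompt, "sampling_params": {"max_new_tokens": 64}}
        for attempt in range(max_retries):
            resp = requests.post("http://localhost:30000/generate", json=payload)
            if resp.status_code != 503:
                resp.raise_for_status()
                return resp.json()
            # Queue full: back off exponentially and try again.
            time.sleep(2 ** attempt)
        raise RuntimeError("request queue stayed full after retries")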
@@ -1256,23 +1190,28 @@ class Scheduler(
  def _add_request_to_queue(self, req: Req):
  req.queue_time_start = time.perf_counter()
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
+ self._prefetch_kvcache(req)
  self.disagg_prefill_bootstrap_queue.add(
  req, self.model_config.num_key_value_heads
  )
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
  self.disagg_decode_prealloc_queue.add(req)
  else:
- if self.enable_hicache_storage:
- req.init_next_round_input(self.tree_cache)
- last_hash = req.last_host_node.get_last_hash_value()
- matched_len = len(req.prefix_indices) + req.host_hit_length
- if (matched_len > 0 and last_hash is not None) or matched_len == 0:
- new_input_tokens = req.fill_ids[matched_len:]
- self.tree_cache.prefetch_from_storage(
- req.rid, req.last_host_node, new_input_tokens, last_hash
- )
+ self._prefetch_kvcache(req)
  self.waiting_queue.append(req)

+ def _prefetch_kvcache(self, req: Req):
+ if self.enable_hicache_storage:
+ req.init_next_round_input(self.tree_cache)
+ last_hash = req.last_host_node.get_last_hash_value()
+ matched_len = len(req.prefix_indices) + req.host_hit_length
+ # todo, free-form fetching, calculating hash keys on the fly
+ if (matched_len > 0 and last_hash is not None) or matched_len == 0:
+ new_input_tokens = req.fill_ids[matched_len:]
+ self.tree_cache.prefetch_from_storage(
+ req.rid, req.last_host_node, new_input_tokens, last_hash
+ )
+
  def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
  self.disagg_prefill_bootstrap_queue.extend(
@@ -1330,170 +1269,11 @@ class Scheduler(
  req.logprob_start_len = len(req.origin_input_ids) - 1
  self._add_request_to_queue(req)

- def _emit_kv_metrics(self):
- kv_metrics = KvMetrics()
- kv_metrics.request_active_slots = self.stats.num_running_reqs
- kv_metrics.request_total_slots = self.max_running_requests
- kv_metrics.kv_active_blocks = int(
- self.stats.token_usage * self.max_total_num_tokens
- )
- kv_metrics.kv_total_blocks = self.max_total_num_tokens
- kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
- kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
- kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
- kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
-
- if not self.send_metrics_from_scheduler.closed:
- self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
-
- def log_prefill_stats(
- self,
- adder: PrefillAdder,
- can_run_list: List[Req],
- running_bs: int,
- ):
- gap_latency = time.perf_counter() - self.last_prefill_stats_tic
- self.last_prefill_stats_tic = time.perf_counter()
- self.last_input_throughput = self.last_prefill_tokens / gap_latency
- self.last_prefill_tokens = adder.log_input_tokens
-
- if self.is_hybrid:
- (
- full_num_used,
- swa_num_used,
- full_token_usage,
- swa_token_usage,
- _,
- _,
- _,
- _,
- ) = self._get_swa_token_info()
- num_used = max(full_num_used, swa_num_used)
- token_usage = max(full_token_usage, swa_token_usage)
- token_msg = (
- f"full token usage: {full_token_usage:.2f}, "
- f"swa token usage: {swa_token_usage:.2f}, "
- )
- else:
- num_used, token_usage, _, _ = self._get_token_info()
- token_msg = f"token usage: {token_usage:.2f}, "
-
- num_new_seq = len(can_run_list)
- f = (
- f"Prefill batch. "
- f"#new-seq: {num_new_seq}, "
- f"#new-token: {adder.log_input_tokens}, "
- f"#cached-token: {adder.log_hit_tokens}, "
- f"{token_msg}"
- )
-
- if self.disaggregation_mode == DisaggregationMode.PREFILL:
- f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
- f += f"#queue-req: {len(self.waiting_queue)}, "
- f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
- f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
- else:
- f += f"#running-req: {running_bs}, "
- f += f"#queue-req: {len(self.waiting_queue)}, "
-
- logger.info(f)
-
- if self.enable_metrics:
- total_tokens = adder.log_input_tokens + adder.log_hit_tokens
-
- cache_hit_rate = (
- adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
- )
- self.stats.num_running_reqs = running_bs
- self.stats.num_used_tokens = num_used
- self.stats.token_usage = round(token_usage, 2)
- self.stats.num_queue_reqs = len(self.waiting_queue)
- self.stats.cache_hit_rate = cache_hit_rate
-
- total_queue_latency = 0
- for req in can_run_list:
- total_queue_latency += req.queue_time_end - req.queue_time_start
- self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
-
- self.metrics_collector.log_stats(self.stats)
- self._emit_kv_metrics()
- self._publish_kv_events()
-
- def log_decode_stats(
- self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
- ):
- batch = running_batch or self.running_batch
-
- gap_latency = time.perf_counter() - self.last_decode_stats_tic
- self.last_decode_stats_tic = time.perf_counter()
- self.last_gen_throughput = self.num_generated_tokens / gap_latency
- self.num_generated_tokens = 0
- num_running_reqs = len(batch.reqs)
- if self.is_hybrid:
- (
- full_num_used,
- swa_num_used,
- full_token_usage,
- swa_token_usage,
- _,
- _,
- _,
- _,
- ) = self._get_swa_token_info()
- num_used = max(full_num_used, swa_num_used)
- token_usage = max(full_token_usage, swa_token_usage)
- token_msg = (
- f"#full token: {full_num_used}, "
- f"full token usage: {full_token_usage:.2f}, "
- f"#swa token: {swa_num_used}, "
- f"swa token usage: {swa_token_usage:.2f}, "
- )
- else:
- num_used, token_usage, _, _ = self._get_token_info()
- token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
-
- if RECORD_STEP_TIME:
- self.step_time_dict[num_running_reqs].append(
- gap_latency / self.server_args.decode_log_interval
- )
-
- msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
-
- if self.spec_algorithm.is_none():
- spec_accept_length = 0
- else:
- spec_accept_length = (
- self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
- )
- self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
- self.cum_spec_accept_count += self.spec_num_total_forward_ct
- self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
- msg += f"accept len: {spec_accept_length:.2f}, "
-
- if self.disaggregation_mode == DisaggregationMode.DECODE:
- msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
- msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
-
- msg += (
- f"cuda graph: {can_run_cuda_graph}, "
- f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
- f"#queue-req: {len(self.waiting_queue)}, "
- )
-
- logger.info(msg)
- if self.enable_metrics:
- self.stats.num_running_reqs = num_running_reqs
- self.stats.num_used_tokens = num_used
- self.stats.token_usage = round(token_usage, 2)
- self.stats.cache_hit_rate = 0.0
- self.stats.gen_throughput = self.last_gen_throughput
- self.stats.num_queue_reqs = len(self.waiting_queue)
- self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
- self.stats.spec_accept_length = spec_accept_length
- self.stats.total_retracted_reqs = self.total_retracted_reqs
- self.metrics_collector.log_stats(self.stats)
- self._emit_kv_metrics()
- self._publish_kv_events()
+ def self_check_during_idle(self):
+ self.check_memory()
+ self.check_tree_cache()
+ self.new_token_ratio = self.init_new_token_ratio
+ self.maybe_sleep_on_idle()

  def check_memory(self):
  if self.is_hybrid:
@@ -2397,22 +2177,6 @@ class Scheduler(
  barrier()
  return RpcReqOutput(success, "" if not exec else str(exec))

- def save_remote_model(self, params):
- url = params["url"]
-
- worker = self.tp_worker.worker
-
- worker.model_runner.save_remote_model(url)
-
- def save_sharded_model(self, params):
- worker = self.tp_worker.worker
-
- worker.model_runner.save_sharded_model(
- path=params["path"],
- pattern=params["pattern"],
- max_size=params["max_size"],
- )
-
  def abort_request(self, recv_req: AbortReq):
  # Delete requests in the waiting queue
  to_del = []
@@ -2490,16 +2254,6 @@ class Scheduler(
  def _pause_engine(self) -> Tuple[List[Req], int]:
  raise NotImplementedError()

- def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
- """In-place update of the weights from disk."""
- success, message = self.tp_worker.update_weights_from_disk(recv_req)
- if success:
- flush_cache_success = self.flush_cache()
- assert flush_cache_success, "Cache flush failed after updating weights"
- else:
- logger.error(message)
- return UpdateWeightFromDiskReqOutput(success, message, 0)
-
  def load_lora_adapter(
  self, recv_req: LoadLoRAAdapterReqInput
  ) -> LoadLoRAAdapterReqOutput:
@@ -2516,81 +2270,6 @@ class Scheduler(
  result = self.tp_worker.unload_lora_adapter(recv_req)
  return result

- def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
- """Initialize the online model parameter update group."""
- success, message = self.tp_worker.init_weights_update_group(recv_req)
- return InitWeightsUpdateGroupReqOutput(success, message)
-
- def update_weights_from_distributed(
- self,
- recv_req: UpdateWeightsFromDistributedReqInput,
- ) -> Tuple[bool, str]:
- """Update the online model parameter."""
- success, message = self.tp_worker.update_weights_from_distributed(recv_req)
- if success:
- if recv_req.flush_cache:
- flush_cache_success = self.flush_cache()
- assert flush_cache_success, "Cache flush failed after updating weights"
- else:
- logger.error(message)
- return UpdateWeightsFromDistributedReqOutput(success, message)
-
- def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
- """Update the online model parameter from tensors."""
- success, message = self.tp_worker.update_weights_from_tensor(recv_req)
- # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
- if success:
- if recv_req.flush_cache:
- flush_cache_success = self.flush_cache()
- assert flush_cache_success, "Cache flush failed after updating weights"
- else:
- logger.error(message)
- barrier(group=self.tp_cpu_group)
- return UpdateWeightsFromTensorReqOutput(success, message)
-
- def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
- parameter = self.tp_worker.get_weights_by_name(recv_req)
- return GetWeightsByNameReqOutput(parameter)
-
- def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
- tags = recv_req.tags
-
- if tags is None or len(tags) == 0:
- tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
-
- if GPU_MEMORY_TYPE_KV_CACHE in tags:
- self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
- self.flush_cache()
-
- if GPU_MEMORY_TYPE_WEIGHTS in tags:
- self.stashed_model_static_state = _export_static_state(
- self.tp_worker.worker.model_runner.model
- )
- torch.distributed.barrier(self.tp_cpu_group)
- self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
-
- return ReleaseMemoryOccupationReqOutput()
-
- def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
- tags = recv_req.tags
-
- if tags is None or len(tags) == 0:
- tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
-
- if GPU_MEMORY_TYPE_WEIGHTS in tags:
- self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
- torch.distributed.barrier(self.tp_cpu_group)
- _import_static_state(
- self.tp_worker.worker.model_runner.model,
- self.stashed_model_static_state,
- )
- del self.stashed_model_static_state
-
- if GPU_MEMORY_TYPE_KV_CACHE in tags:
- self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
-
- return ResumeMemoryOccupationReqOutput()
-
  def slow_down(self, recv_req: SlowDownReqInput):
  t = recv_req.forward_sleep_time
  if t is not None and t <= 0:
@@ -2598,254 +2277,6 @@ class Scheduler(
  self.forward_sleep_time = t
  return SlowDownReqOutput()

- def profile(self, recv_req: ProfileReq):
- if recv_req.type == ProfileReqType.START_PROFILE:
- if recv_req.profile_by_stage or recv_req.start_step:
- return self.init_profile(
- recv_req.output_dir,
- recv_req.start_step,
- recv_req.num_steps,
- recv_req.activities,
- recv_req.with_stack,
- recv_req.record_shapes,
- recv_req.profile_by_stage,
- recv_req.profile_id,
- )
- else:
- self.init_profile(
- recv_req.output_dir,
- recv_req.start_step,
- recv_req.num_steps,
- recv_req.activities,
- recv_req.with_stack,
- recv_req.record_shapes,
- recv_req.profile_by_stage,
- recv_req.profile_id,
- )
- return self.start_profile(True)
- else:
- return self.stop_profile()
-
- def init_profile(
- self,
- output_dir: Optional[str],
- start_step: Optional[int],
- num_steps: Optional[int],
- activities: Optional[List[str]],
- with_stack: Optional[bool],
- record_shapes: Optional[bool],
- profile_by_stage: bool,
- profile_id: str,
- ) -> ProfileReqOutput:
- if self.profile_in_progress:
- return ProfileReqOutput(
- success=False,
- message="Profiling is already in progress. Call /stop_profile first.",
- )
-
- self.profile_by_stage = profile_by_stage
-
- if output_dir is None:
- output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
- if activities is None:
- activities = ["CPU", "GPU"]
-
- self.torch_profiler_output_dir = output_dir
- self.torch_profiler_with_stack = with_stack
- self.torch_profiler_record_shapes = record_shapes
- self.profiler_activities = activities
- self.profile_id = profile_id
-
- if start_step:
- self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
-
- if num_steps:
- self.profile_steps = num_steps
- if self.profile_by_stage:
- self.profiler_target_prefill_ct = num_steps
- self.profiler_target_decode_ct = num_steps
- self.profiler_prefill_ct = 0
- self.profiler_decode_ct = 0
- elif start_step:
- self.profiler_target_forward_ct = (
- self.profiler_start_forward_ct + num_steps
- )
- else:
- self.profiler_target_forward_ct = self.forward_ct + num_steps
- # The caller will be notified when reaching profiler_target_forward_ct
- else:
- self.profiler_target_forward_ct = None
-
- return ProfileReqOutput(success=True, message="Succeeded")
-
- def start_profile(
- self, stage: Optional[ForwardMode] = None
- ) -> ProfileReqOutput | None:
- stage_str = f" for {stage.__str__()}" if stage else ""
- logger.info(
- f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
- )
-
- activities = self.profiler_activities
- with_stack = self.torch_profiler_with_stack
- record_shapes = self.torch_profiler_record_shapes
-
- activity_map = {
- "CPU": torch.profiler.ProfilerActivity.CPU,
- "GPU": torch.profiler.ProfilerActivity.CUDA,
- }
- torchprof_activities = [
- activity_map[a] for a in activities if a in activity_map
- ]
-
- if "RPD" in activities:
- from rpdTracerControl import rpdTracerControl
-
- rpdTracerControl.skipCreate()
-
- self.rpd_profile_path = os.path.join(
- self.torch_profiler_output_dir,
- "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
- )
-
- if self.tp_rank == 0:
- import sqlite3
-
- from rocpd.schema import RocpdSchema
-
- if os.path.exists("trace.rpd"):
- os.unlink("trace.rpd")
- schema = RocpdSchema()
- connection = sqlite3.connect("trace.rpd")
- schema.writeSchema(connection)
- connection.commit()
- del connection
- torch.distributed.barrier(self.tp_cpu_group)
-
- self.rpd_profiler = rpdTracerControl()
- self.rpd_profiler.setPythonTrace(True)
- self.rpd_profiler.start()
- self.rpd_profiler.rangePush("", "rpd profile range", "")
- self.profile_in_progress = True
- elif torchprof_activities:
- self.torch_profiler = torch.profiler.profile(
- activities=torchprof_activities,
- with_stack=with_stack if with_stack is not None else True,
- record_shapes=record_shapes if record_shapes is not None else False,
- )
- self.torch_profiler.start()
- self.profile_in_progress = True
-
- if "MEM" in activities:
- torch.cuda.memory._record_memory_history(max_entries=100000)
- self.profile_in_progress = True
-
- if "CUDA_PROFILER" in activities:
- torch.cuda.cudart().cudaProfilerStart()
- self.profile_in_progress = True
-
- return ProfileReqOutput(success=True, message="Succeeded")
-
- def stop_profile(
- self, stage: Optional[ForwardMode] = None
- ) -> ProfileReqOutput | None:
- if not self.profile_in_progress:
- return ProfileReqOutput(
- success=False,
- message="Profiling is not in progress. Call /start_profile first.",
- )
-
- if not Path(self.torch_profiler_output_dir).exists():
- Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
-
- stage_suffix = f"-{stage.__str__()}" if stage else ""
- logger.info("Stop profiling" + stage_suffix + "...")
- if self.torch_profiler is not None:
- self.torch_profiler.stop()
- self.torch_profiler.export_chrome_trace(
- os.path.join(
- self.torch_profiler_output_dir,
- self.profile_id
- + f"-TP-{self.tp_rank}"
- + stage_suffix
- + ".trace.json.gz",
- )
- )
- torch.distributed.barrier(self.tp_cpu_group)
-
- if self.rpd_profiler is not None:
- self.rpd_profiler.rangePop()
- self.rpd_profiler.stop()
- self.rpd_profiler.flush()
-
- torch.distributed.barrier(self.tp_cpu_group)
- if self.tp_rank == 0:
- from sglang.srt.utils import rpd_to_chrome_trace
-
- rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
- self.rpd_profiler = None
- self.rpd_profiler_path = None
-
- if self.profiler_activities is not None and "MEM" in self.profiler_activities:
- memory_profile_path = os.path.join(
- self.torch_profiler_output_dir,
- str(time.time())
- + f"-TP-{self.tp_rank}-memory"
- + stage_suffix
- + ".pickle",
- )
- torch.cuda.memory._dump_snapshot(memory_profile_path)
- torch.cuda.memory._record_memory_history(enabled=None)
-
- if "CUDA_PROFILER" in self.profiler_activities:
- torch.cuda.cudart().cudaProfilerStop()
-
- logger.info(
- "Profiling done. Traces are saved to: %s",
- self.torch_profiler_output_dir,
- )
- self.torch_profiler = None
- self.profile_in_progress = False
- self.profiler_start_forward_ct = None
-
- return ProfileReqOutput(success=True, message="Succeeded.")
-
- def _profile_batch_predicate(self, batch):
- if self.profile_by_stage:
- if batch.forward_mode.is_prefill():
- if self.profiler_prefill_ct == 0:
- self.start_profile(batch.forward_mode)
- self.profiler_prefill_ct += 1
- if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
- if self.profile_in_progress:
- self.stop_profile(stage=ForwardMode.EXTEND)
- elif batch.forward_mode.is_decode():
- if self.profiler_decode_ct == 0:
- if self.profile_in_progress:
- # force trace flush
- self.stop_profile(ForwardMode.EXTEND)
- self.start_profile(batch.forward_mode)
- self.profiler_decode_ct += 1
- if self.profiler_decode_ct > self.profiler_target_decode_ct:
- if self.profile_in_progress:
- self.stop_profile(stage=ForwardMode.DECODE)
- elif batch.forward_mode.is_idle():
- pass
- else:
- raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
- else:
- # Check profiler
- if (
- self.profiler_target_forward_ct
- and self.profiler_target_forward_ct <= self.forward_ct
- ):
- self.stop_profile()
- if (
- self.profiler_start_forward_ct
- and self.profiler_start_forward_ct == self.forward_ct
- ):
- self.start_profile()
-
  def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
  if recv_req == ExpertDistributionReq.START_RECORD:
  get_global_expert_distribution_recorder().start_record()
@@ -2854,7 +2285,7 @@ class Scheduler(
  elif recv_req == ExpertDistributionReq.DUMP_RECORD:
  get_global_expert_distribution_recorder().dump_record()
  else:
- raise ValueError("Unrecognized ExpertDistributionReq value")
+ raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
  return ExpertDistributionReqOutput()

  def open_session(self, recv_req: OpenSessionReqInput):
@@ -2890,30 +2321,41 @@ class Scheduler(
  prefix += f" PP{self.pp_rank}"
  return prefix

- def _publish_kv_events(self):
- if self.enable_kv_cache_events:
- events = self.tree_cache.take_events()
- if events:
- batch = KVEventBatch(ts=time.time(), events=events)
- self.kv_event_publisher.publish(batch)
+ def current_scheduler_metrics_enabled(self):
+ return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers

+ def maybe_sleep_on_idle(self):
+ if self.idle_sleeper is not None:
+ self.idle_sleeper.maybe_sleep()

- def is_health_check_generate_req(recv_req):
- return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")

+ class IdleSleeper:
+ """
+ In setups which have long inactivity periods it is desirable to reduce
+ system power consumption when sglang does nothing. This would lead not only
+ to power savings, but also to more CPU thermal headroom when a request
+ eventually comes. This is important in cases when multiple GPUs are connected
+ as each GPU would otherwise pin one thread at 100% CPU usage.

- def _export_static_state(model):
- return dict(
- buffers=[
- (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
- ]
- )
+ The simplest solution is to use zmq.Poller on all sockets that may receive
+ data that needs handling immediately.
+ """

+ def __init__(self, sockets):
+ self.poller = zmq.Poller()
+ for s in sockets:
+ self.poller.register(s, zmq.POLLIN)

- def _import_static_state(model, static_params):
- self_named_buffers = dict(model.named_buffers())
- for name, tensor in static_params["buffers"]:
- self_named_buffers[name][...] = tensor
+ def maybe_sleep(self):
+ self.poller.poll(1000)
+
+
+ def is_health_check_generate_req(recv_req):
+ return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
+
+
+ def is_work_request(recv_req):
+ return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))


  def run_scheduler_process(
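The IdleSleeper docstring in the hunk above explains the idle-sleep mechanism: rather than spinning on non-blocking receives, the scheduler blocks in zmq.Poller.poll() so an idle process uses almost no CPU. A standalone, hedged sketch of the same idea (socket address and message handling are illustrative, not sglang's actual IPC names):

    # Minimal sketch of the zmq.Poller idle-sleep pattern used by IdleSleeper.
    import zmq

    context = zmq.Context()
    receiver = context.socket(zmq.PULL)
    receiver.bind("ipc:///tmp/demo_scheduler_input")  # illustrative address

    poller = zmq.Poller()
    poller.register(receiver, zmq.POLLIN)

    while True:
        # Returns early as soon as a message arrives; otherwise wakes at most
        # once per second instead of pinning a CPU core.
        events = dict(poller.poll(1000))
        if receiver in events:
            msg = receiver.recv_pyobj()
            print("got work:", msg)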
@@ -2921,6 +2363,7 @@ def run_scheduler_process(
  port_args: PortArgs,
  gpu_id: int,
  tp_rank: int,
+ moe_ep_rank: int,
  pp_rank: int,
  dp_rank: Optional[int],
  pipe_writer,
@@ -2931,6 +2374,8 @@
  prefix += f" DP{dp_rank}"
  if server_args.tp_size > 1:
  prefix += f" TP{tp_rank}"
+ if server_args.ep_size > 1:
+ prefix += f" EP{moe_ep_rank}"
  if server_args.pp_size > 1:
  prefix += f" PP{pp_rank}"

@@ -2954,7 +2399,9 @@

  # Create a scheduler and run the event loop
  try:
- scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
+ scheduler = Scheduler(
+ server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank
+ )
  pipe_writer.send(
  {
  "status": "ready",