sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,6 @@ import signal
  import sys
  import threading
  import time
- import warnings
  from collections import defaultdict, deque
  from concurrent import futures
  from dataclasses import dataclass
@@ -42,14 +41,17 @@ from sglang.srt.disaggregation.decode import (
  DecodeTransferQueue,
  SchedulerDisaggregationDecodeMixin,
  )
+ from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
  from sglang.srt.disaggregation.prefill import (
  PrefillBootstrapQueue,
  SchedulerDisaggregationPrefillMixin,
  )
  from sglang.srt.disaggregation.utils import (
  DisaggregationMode,
+ MetadataBuffers,
  ReqToMetadataIdxAllocator,
  TransferBackend,
+ prepare_abort,
  )
  from sglang.srt.distributed import get_pp_group, get_world_group
  from sglang.srt.hf_transformers_utils import (
@@ -59,7 +61,10 @@ from sglang.srt.hf_transformers_utils import (
  )
  from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
- from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+ from sglang.srt.managers.expert_distribution import (
+ ExpertDistributionRecorder,
+ get_global_expert_distribution_recorder,
+ )
  from sglang.srt.managers.io_struct import (
  AbortReq,
  CloseSessionReqInput,
@@ -98,6 +103,7 @@ from sglang.srt.managers.io_struct import (
  UpdateWeightsFromTensorReqInput,
  UpdateWeightsFromTensorReqOutput,
  )
+ from sglang.srt.managers.mm_utils import init_embedding_cache
  from sglang.srt.managers.schedule_batch import (
  FINISH_ABORT,
  MultimodalInputs,
@@ -121,11 +127,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
- from sglang.srt.model_executor.forward_batch_info import (
- ForwardBatch,
- ForwardMode,
- PPProxyTensors,
- )
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
  from sglang.srt.reasoning_parser import ReasoningParser
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -134,7 +136,7 @@ from sglang.srt.utils import (
  DynamicGradMode,
  broadcast_pyobj,
  configure_logger,
- crash_on_warnings,
+ disable_request_logging,
  get_bool_env_var,
  get_zmq_socket,
  kill_itself_when_parent_died,
@@ -146,13 +148,12 @@ from sglang.srt.utils import (
  )
  from sglang.utils import TypeBasedDispatcher, get_exception_traceback

- expert_distribution_recorder = ExpertDistributionRecorder()
-
  logger = logging.getLogger(__name__)

  # Test retract decode for debugging purposes
  TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
  RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
+ GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))


  @dataclass
@@ -163,6 +164,7 @@ class GenerationBatchResult:
  extend_input_len_per_req: List[int]
  extend_logprob_start_len_per_req: List[int]
  bid: int
+ can_run_cuda_graph: bool


  @dataclass
@@ -200,6 +202,7 @@ class Scheduler(
  self.enable_overlap = not server_args.disable_overlap_schedule
  self.skip_tokenizer_init = server_args.skip_tokenizer_init
  self.enable_metrics = server_args.enable_metrics
+ self.enable_kv_cache_events = server_args.kv_events_config is not None
  self.stream_interval = server_args.stream_interval
  self.spec_algorithm = SpeculativeAlgorithm.from_string(
  server_args.speculative_algorithm
@@ -207,9 +210,9 @@ class Scheduler(
  self.gpu_id = gpu_id
  self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
  self.page_size = server_args.page_size
-
  # Distributed rank info
- self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+ self.dp_size = server_args.dp_size
+ self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
  compute_dp_attention_world_info(
  server_args.enable_dp_attention,
  self.tp_rank,
@@ -326,13 +329,14 @@ class Scheduler(
  set_random_seed(self.random_seed)

  # Print debug info
- logger.info(
- f"max_total_num_tokens={self.max_total_num_tokens}, "
- f"chunked_prefill_size={server_args.chunked_prefill_size}, "
- f"max_prefill_tokens={self.max_prefill_tokens}, "
- f"max_running_requests={self.max_running_requests}, "
- f"context_len={self.model_config.context_len}"
- )
+ if tp_rank == 0:
+ logger.info(
+ f"max_total_num_tokens={self.max_total_num_tokens}, "
+ f"chunked_prefill_size={server_args.chunked_prefill_size}, "
+ f"max_prefill_tokens={self.max_prefill_tokens}, "
+ f"max_running_requests={self.max_running_requests}, "
+ f"context_len={self.model_config.context_len}"
+ )

  # Init memory pool and cache
  self.init_memory_pool_and_cache()
@@ -349,8 +353,8 @@ class Scheduler(
  self.forward_ct_decode = 0
  self.num_generated_tokens = 0
  self.num_prefill_tokens = 0
- self.last_decode_stats_tic = time.time()
- self.last_prefill_stats_tic = time.time()
+ self.last_decode_stats_tic = time.perf_counter()
+ self.last_prefill_stats_tic = time.perf_counter()
  self.return_health_check_ct = 0
  self.current_stream = torch.get_device_module(self.device).current_stream()
  if self.device == "cpu":
@@ -423,6 +427,7 @@ class Scheduler(

  # Init metrics stats
  self.init_metrics()
+ self.init_kv_events(server_args.kv_events_config)

  # Init request dispatcher
  self._request_dispatcher = TypeBasedDispatcher(
@@ -516,6 +521,7 @@ class Scheduler(
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
  page_size=self.page_size,
  disable=server_args.disable_radix_cache,
+ enable_kv_cache_events=self.enable_kv_cache_events,
  )

  self.decode_mem_cache_buf_multiplier = (
@@ -531,10 +537,6 @@ class Scheduler(
  )

  def init_metrics(self):
- # The largest prefill length of a single request
- self._largest_prefill_len: int = 0
- # The largest context length (prefill + generation) of a single request
- self._largest_prefill_decode_len: int = 0
  self.last_gen_throughput: float = 0.0
  self.last_input_throughput: float = 0.0
  self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
@@ -552,6 +554,10 @@ class Scheduler(
  },
  )

+ def init_kv_events(self, kv_events_config: Optional[str]):
+ if self.enable_kv_cache_events:
+ self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)
+
  def init_disaggregation(self):
  self.transfer_backend = TransferBackend(
  self.server_args.disaggregation_transfer_backend
@@ -564,29 +570,28 @@ class Scheduler(
  req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
  buffer_size
  )
- aux_dtype = torch.int32
- # A list of metadata buffers. The shape is (b, metadata_size) where
- # b corresponds to a max running requests. The last shape * dtype.itemsize
- # should be larger than 64 bytes to work with RDMA, so we pad it.
- output_id_buffer = torch.zeros(
- (buffer_size, 16), dtype=aux_dtype, device="cpu"
- )
- metadata_buffers = [output_id_buffer]
+ self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

  # The decode requests polling kv cache
  self.disagg_decode_transfer_queue = DecodeTransferQueue(
  gloo_group=self.attn_tp_cpu_group,
  req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
- metadata_buffers=metadata_buffers,
+ metadata_buffers=self.disagg_metadata_buffers,
+ scheduler=self,
+ tree_cache=self.tree_cache,
  )

  # The decode requests pending for pre-allocation
  self.disagg_decode_prealloc_queue = DecodePreallocQueue(
  req_to_token_pool=self.req_to_token_pool,
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+ draft_token_to_kv_pool=(
+ None
+ if self.draft_worker is None
+ else self.draft_worker.model_runner.token_to_kv_pool
+ ),
  req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
- metadata_buffers=metadata_buffers,
- aux_dtype=aux_dtype,
+ metadata_buffers=self.disagg_metadata_buffers,
  scheduler=self,
  transfer_queue=self.disagg_decode_transfer_queue,
  tree_cache=self.tree_cache,
@@ -606,20 +611,17 @@ class Scheduler(
  req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
  buffer_size
  )
- aux_dtype = torch.int32
- # A list of metadata buffers. The shape is (b, metadata_size) where
- # b corresponds to a max running requests. The last shape * dtype.itemsize
- # should be larger than 64 bytes to work with RDMA, so we pad it.
- output_id_buffer = torch.zeros(
- (buffer_size, 16), dtype=aux_dtype, device="cpu"
- )
- metadata_buffers = [output_id_buffer]
+ self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

  self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
  token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
+ draft_token_to_kv_pool=(
+ None
+ if self.draft_worker is None
+ else self.draft_worker.model_runner.token_to_kv_pool
+ ),
  req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
- metadata_buffers=metadata_buffers,
- aux_dtype=aux_dtype,
+ metadata_buffers=self.disagg_metadata_buffers,
  tp_rank=self.tp_rank,
  tp_size=self.tp_size,
  bootstrap_port=self.server_args.disaggregation_bootstrap_port,
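In both disaggregation modes the hand-rolled metadata tensors are replaced by a MetadataBuffers object, which now lives in sglang/srt/disaggregation/utils.py (per the import hunk above). Its internals are not shown here; the sketch below only restates the padding arithmetic that the removed inline code documented, namely that each per-request row of 16 int32 values is at least 64 bytes, the minimum the old comment cites for RDMA transfers (the buffer_size value is a hypothetical example).

    import torch

    row_elems = 16                 # per-request metadata width kept from the old buffer shape
    row_bytes = row_elems * 4      # int32 is 4 bytes -> 16 * 4 = 64 bytes per request
    assert row_bytes >= 64         # satisfies the 64-byte minimum noted for RDMA

    buffer_size = 128              # hypothetical max running requests
    output_id_buffer = torch.zeros((buffer_size, row_elems), dtype=torch.int32, device="cpu")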
@@ -720,7 +722,7 @@ class Scheduler(
  server_is_idle = False
  result = self.run_batch(self.cur_batch)

- # send the outputs to the next step
+ # (last rank) send the outputs to the next step
  if self.pp_group.is_last_rank:
  if self.cur_batch:
  next_token_ids, bids[mb_id] = (
@@ -755,24 +757,25 @@ class Scheduler(
  extend_input_len_per_req=None,
  extend_logprob_start_len_per_req=None,
  bid=bids[next_mb_id],
+ can_run_cuda_graph=result.can_run_cuda_graph,
  )
  self.process_batch_result(mbs[next_mb_id], output_result)
  last_mbs[next_mb_id] = mbs[next_mb_id]

- # carry the outputs to the next stage
+ # (not last rank)
  if not self.pp_group.is_last_rank:
  if self.cur_batch:
  bids[mb_id] = result.bid
+ # carry the outputs to the next stage
+ # send the outputs from the last round to let the next stage worker run post processing
  if pp_outputs:
- # send the outputs from the last round to let the next stage worker run post processing
  self.pp_group.send_tensor_dict(
  pp_outputs.tensors,
  all_gather_group=self.attn_tp_group,
  )

- if not self.pp_group.is_last_rank:
  # send out reqs to the next stage
- dp_offset = self.dp_rank * self.attn_tp_size
+ dp_offset = self.attn_dp_rank * self.attn_tp_size
  if self.attn_tp_rank == 0:
  point_to_point_pyobj(
  recv_reqs,
@@ -819,7 +822,7 @@ class Scheduler(
  recv_reqs = None
  else:
  if self.attn_tp_rank == 0:
- dp_offset = self.dp_rank * self.attn_tp_size
+ dp_offset = self.attn_dp_rank * self.attn_tp_size
  recv_reqs = point_to_point_pyobj(
  [],
  self.pp_rank * self.tp_size + dp_offset,
@@ -907,19 +910,6 @@ class Scheduler(
  fake_input_ids = [1] * seq_length
  recv_req.input_ids = fake_input_ids

- # Handle custom logit processor passed to the request
- custom_logit_processor = recv_req.custom_logit_processor
- if (
- not self.server_args.enable_custom_logit_processor
- and custom_logit_processor is not None
- ):
- logger.warning(
- "The SGLang server is not configured to enable custom logit processor."
- "The custom logit processor passed in will be ignored."
- "Please set --enable-custom-logits-processor to enable this feature."
- )
- custom_logit_processor = None
-
  if recv_req.bootstrap_port is None:
  # Use default bootstrap port
  recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
@@ -935,7 +925,7 @@ class Scheduler(
  stream=recv_req.stream,
  lora_path=recv_req.lora_path,
  input_embeds=recv_req.input_embeds,
- custom_logit_processor=custom_logit_processor,
+ custom_logit_processor=recv_req.custom_logit_processor,
  return_hidden_states=recv_req.return_hidden_states,
  eos_token_ids=self.model_config.hf_eos_token_id,
  bootstrap_host=recv_req.bootstrap_host,
@@ -944,6 +934,18 @@ class Scheduler(
  )
  req.tokenizer = self.tokenizer

+ if self.disaggregation_mode != DisaggregationMode.NULL:
+ # Invalid request for disaggregated mode
+ if recv_req.bootstrap_room is None:
+ error_message = (
+ f"Invalid request: Disaggregated request received without "
+ f"boostrap room id. {req.rid=}"
+ )
+ logger.error(error_message)
+ prepare_abort(req, error_message)
+ self.stream_output([req], req.return_logprob)
+ return
+
  if (
  recv_req.session_params is not None
  and recv_req.session_params.id is not None
@@ -1041,19 +1043,21 @@ class Scheduler(
  elif req.sampling_params.structural_tag:
  key = ("structural_tag", req.sampling_params.structural_tag)

- req.grammar = self.grammar_backend.get_cached_value(key)
- if not req.grammar:
- req.grammar = self.grammar_backend.get_future_value(key)
+ value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
+ req.grammar = value
+
+ if not cache_hit:
+ req.grammar_key = key
  add_to_grammar_queue = True

  if add_to_grammar_queue:
- req.queue_time_start = time.time()
+ req.queue_time_start = time.perf_counter()
  self.grammar_queue.append(req)
  else:
  self._add_request_to_queue(req)

  def _add_request_to_queue(self, req: Req):
- req.queue_time_start = time.time()
+ req.queue_time_start = time.perf_counter()
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
  self.disagg_prefill_bootstrap_queue.add(req)
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
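The grammar backend API changes here from a two-call lookup (get_cached_value, then get_future_value) to a single get_cached_or_future_value that also reports whether the key hit the cache. The real implementation is in sglang/srt/constrained/base_grammar_backend.py, which is rewritten in this release but not shown in this hunk; the sketch below only illustrates the (value, cache_hit) contract the scheduler now relies on, with hypothetical class and method internals.

    from concurrent.futures import ThreadPoolExecutor
    from typing import Any, Dict, Tuple

    class CacheOrFutureBackend:
        """Illustrative only: return a cached grammar, or a Future that will produce one."""

        def __init__(self) -> None:
            self._cache: Dict[Any, Any] = {}
            self._executor = ThreadPoolExecutor(max_workers=1)

        def get_cached_or_future_value(self, key) -> Tuple[Any, bool]:
            if key in self._cache:
                return self._cache[key], True                 # hit: ready-to-use value
            return self._executor.submit(self._compile, key), False  # miss: a Future

        def set_cache(self, key, value) -> None:
            self._cache[key] = value                          # filled once the Future resolves

        def _compile(self, key):
            return f"compiled<{key}>"                         # stand-in for grammar compilation

On a miss the scheduler keeps the returned Future in req.grammar and later resolves it with .result(timeout=...) in move_ready_grammar_requests, then calls set_cache with a copy, matching the hunks further down.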
@@ -1061,8 +1065,11 @@ class Scheduler(
  else:
  self.waiting_queue.append(req)

- def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
- if self.disaggregation_mode == DisaggregationMode.DECODE:
+ def _extend_requests_to_queue(self, reqs: List[Req]):
+ if self.disaggregation_mode == DisaggregationMode.PREFILL:
+ self.disagg_prefill_bootstrap_queue.extend(reqs)
+ elif self.disaggregation_mode == DisaggregationMode.DECODE:
+ # If this is a decode server, we put the request to the decode pending prealloc queue
  self.disagg_decode_prealloc_queue.extend(reqs)
  else:
  self.waiting_queue.extend(reqs)
@@ -1100,7 +1107,7 @@ class Scheduler(
  req.finished_reason = FINISH_ABORT(
  error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
  )
- req.queue_time_start = time.time()
+ req.queue_time_start = time.perf_counter()
  self.waiting_queue.append(req)
  return

@@ -1124,8 +1131,8 @@ class Scheduler(
  can_run_list: List[Req],
  running_bs: int,
  ):
- gap_latency = time.time() - self.last_prefill_stats_tic
- self.last_prefill_stats_tic = time.time()
+ gap_latency = time.perf_counter() - self.last_prefill_stats_tic
+ self.last_prefill_stats_tic = time.perf_counter()
  self.last_input_throughput = self.num_prefill_tokens / gap_latency
  self.num_prefill_tokens = 0

@@ -1133,9 +1140,6 @@ class Scheduler(
  self.token_to_kv_pool_allocator.available_size()
  + self.tree_cache.evictable_size()
  )
- self._largest_prefill_len = max(
- self._largest_prefill_len, adder.log_input_tokens
- )

  num_new_seq = len(can_run_list)
  f = (
@@ -1172,12 +1176,15 @@ class Scheduler(
  self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq

  self.metrics_collector.log_stats(self.stats)
+ self._publish_kv_events()

- def log_decode_stats(self, running_batch=None):
+ def log_decode_stats(
+ self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+ ):
  batch = running_batch or self.running_batch

- gap_latency = time.time() - self.last_decode_stats_tic
- self.last_decode_stats_tic = time.time()
+ gap_latency = time.perf_counter() - self.last_decode_stats_tic
+ self.last_decode_stats_tic = time.perf_counter()
  self.last_gen_throughput = self.num_generated_tokens / gap_latency
  self.num_generated_tokens = 0
  num_running_reqs = len(batch.reqs)
@@ -1213,6 +1220,7 @@ class Scheduler(
  msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

  msg += (
+ f"cuda graph: {can_run_cuda_graph}, "
  f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
  f"#queue-req: {len(self.waiting_queue)}"
  )
@@ -1225,8 +1233,10 @@ class Scheduler(
  self.stats.cache_hit_rate = 0.0
  self.stats.gen_throughput = self.last_gen_throughput
  self.stats.num_queue_reqs = len(self.waiting_queue)
+ self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
  self.stats.spec_accept_length = spec_accept_length
  self.metrics_collector.log_stats(self.stats)
+ self._publish_kv_events()

  def check_memory(self):
  available_size = (
@@ -1246,9 +1256,7 @@ class Scheduler(
  f"{self.token_to_kv_pool_allocator.available_size()=}\n"
  f"{self.tree_cache.evictable_size()=}\n"
  )
- warnings.warn(msg)
- if crash_on_warnings():
- raise ValueError(msg)
+ raise ValueError(msg)

  if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
  msg = (
@@ -1256,14 +1264,12 @@ class Scheduler(
  f"available_size={len(self.req_to_token_pool.free_slots)}, "
  f"total_size={self.req_to_token_pool.size}\n"
  )
- warnings.warn(msg)
- if crash_on_warnings():
- raise ValueError(msg)
+ raise ValueError(msg)

  if (
  self.enable_metrics
  and self.attn_tp_rank == 0
- and time.time() > self.metrics_collector.last_log_time + 30
+ and time.perf_counter() > self.metrics_collector.last_log_time + 30
  ):
  # During idle time, also collect metrics every 30 seconds.
  num_used = self.max_total_num_tokens - (
@@ -1276,7 +1282,9 @@ class Scheduler(
  self.stats.token_usage = num_used / self.max_total_num_tokens
  self.stats.gen_throughput = 0
  self.stats.num_queue_reqs = len(self.waiting_queue)
+ self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
  self.metrics_collector.log_stats(self.stats)
+ self._publish_kv_events()

  def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
  # Merge the prefill batch into the running batch
@@ -1346,7 +1354,7 @@ class Scheduler(
  return None

  running_bs = len(self.running_batch.reqs)
- # Igore the check if self.chunked_req is not None.
+ # Ignore the check if self.chunked_req is not None.
  # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
  # as the space for the chunked request has just been released.
  # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
@@ -1399,6 +1407,13 @@ class Scheduler(
  self.running_batch.batch_is_full = True
  break

+ if self.disaggregation_mode == DisaggregationMode.PREFILL:
+ # In prefill mode, prealloc queue and transfer queue can also take memory,
+ # so we need to check if the available size for the actual available size.
+ if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
+ self.running_batch.batch_is_full = True
+ break
+
  req.init_next_round_input(
  None if prefix_computed else self.tree_cache,
  self.enable_hierarchical_cache,
@@ -1427,7 +1442,7 @@ class Scheduler(
  if self.enable_metrics:
  # only record queue time when enable_metrics is True to avoid overhead
  for req in can_run_list:
- req.queue_time_end = time.time()
+ req.queue_time_end = time.perf_counter()

  self.waiting_queue = [
  x for x in self.waiting_queue if x not in set(can_run_list)
@@ -1529,7 +1544,7 @@ class Scheduler(
  self.profiler_target_forward_ct
  and self.profiler_target_forward_ct <= self.forward_ct
  ):
- self.stop_profile()
+ self.send_to_tokenizer.send_pyobj(self.stop_profile())

  if self.forward_sleep_time is not None:
  logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
@@ -1540,11 +1555,11 @@ class Scheduler(
  if self.spec_algorithm.is_none():
  model_worker_batch = batch.get_model_worker_batch()
  if self.pp_group.is_last_rank:
- logits_output, next_token_ids = (
+ logits_output, next_token_ids, can_run_cuda_graph = (
  self.tp_worker.forward_batch_generation(model_worker_batch)
  )
  else:
- pp_hidden_states_proxy_tensors, _ = (
+ pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
  self.tp_worker.forward_batch_generation(model_worker_batch)
  )
  bid = model_worker_batch.bid
@@ -1554,6 +1569,7 @@ class Scheduler(
  next_token_ids,
  bid,
  num_accepted_tokens,
+ can_run_cuda_graph,
  ) = self.draft_worker.forward_batch_speculative_generation(batch)
  self.spec_num_total_accepted_tokens += (
  num_accepted_tokens + batch.batch_size()
@@ -1587,6 +1603,7 @@ class Scheduler(
  extend_input_len_per_req=extend_input_len_per_req,
  extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
  bid=bid,
+ can_run_cuda_graph=can_run_cuda_graph,
  )
  else: # embedding or reward model
  model_worker_batch = batch.get_model_worker_batch()
@@ -1609,14 +1626,9 @@ class Scheduler(
  elif batch.forward_mode.is_idle():
  if self.enable_overlap:
  self.tp_worker.resolve_last_batch_result(launch_done)
- if batch.next_batch_sampling_info:
- batch.next_batch_sampling_info.update_regex_vocab_mask()
- self.current_stream.synchronize()
- batch.next_batch_sampling_info.sampling_info_done.set()
+ self.set_next_batch_sampling_info_done(batch)
  elif batch.forward_mode.is_dummy_first():
- batch.next_batch_sampling_info.update_regex_vocab_mask()
- self.current_stream.synchronize()
- batch.next_batch_sampling_info.sampling_info_done.set()
+ self.set_next_batch_sampling_info_done(batch)

  if self.return_health_check_ct:
  # Return some signal for the health check.
@@ -1630,6 +1642,7 @@ class Scheduler(
  local_batch,
  dp_size=self.server_args.dp_size,
  attn_tp_size=self.attn_tp_size,
+ moe_dense_tp_size=self.server_args.moe_dense_tp_size,
  tp_cpu_group=self.tp_cpu_group,
  get_idle_batch=self.get_idle_batch,
  disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1642,6 +1655,7 @@ class Scheduler(
  local_batch: ScheduleBatch,
  dp_size,
  attn_tp_size: int,
+ moe_dense_tp_size: Optional[int],
  tp_cpu_group,
  get_idle_batch,
  disable_cuda_graph: bool,
@@ -1651,15 +1665,15 @@ class Scheduler(
  # Check if other DP workers have running batches
  if local_batch is None:
  num_tokens = 0
- global_num_tokens_for_logprob = 0
+ num_tokens_for_logprob = 0
  elif local_batch.forward_mode.is_decode():
  num_tokens = local_batch.batch_size()
  if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
  num_tokens = num_tokens * speculative_num_draft_tokens
- global_num_tokens_for_logprob = num_tokens
+ num_tokens_for_logprob = num_tokens
  else:
  num_tokens = local_batch.extend_num_tokens
- global_num_tokens_for_logprob = sum(
+ num_tokens_for_logprob = sum(
  [
  # We should have at least 1 token for sample in every case.
  max(extend_len - logprob_start_len, 1)
@@ -1686,7 +1700,7 @@ class Scheduler(
  [
  num_tokens,
  can_cuda_graph,
- global_num_tokens_for_logprob,
+ num_tokens_for_logprob,
  is_extend_in_batch,
  ],
  dtype=torch.int64,
@@ -1709,8 +1723,15 @@ class Scheduler(
  local_batch = get_idle_batch()

  if local_batch is not None:
- local_batch.global_num_tokens = global_num_tokens
- local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
+ # TODO: handle the case when moe_dense_tp_size != 1
+ if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
+ local_batch.global_num_tokens = [num_tokens]
+ local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
+ else:
+ local_batch.global_num_tokens = global_num_tokens
+ local_batch.global_num_tokens_for_logprob = (
+ global_num_tokens_for_logprob
+ )

  # Check forward mode for cuda graph
  if not disable_cuda_graph:
@@ -1736,11 +1757,17 @@ class Scheduler(
  """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""

  num_ready_reqs = 0
+ num_abort_reqs = 0
  for req in self.grammar_queue:
  try:
- req.grammar = req.grammar.result(timeout=0.05)
+ req.grammar = req.grammar.result(timeout=0.03)
+ if req.grammar:
+ self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
  num_ready_reqs += 1
  except futures._base.TimeoutError:
+ req.grammar_wait_ct += 1
+ if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
+ num_abort_reqs = 1
  break

  if self.server_args.enable_dp_attention:
@@ -1752,46 +1779,70 @@ class Scheduler(

  if tp_size > 1:
  # Sync across TP ranks to make sure they have the same number of ready requests
- tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+ tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
  torch.distributed.all_reduce(
  tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
  )
- num_ready_reqs_max = tensor.item()
+ num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()
+
  for i in range(num_ready_reqs, num_ready_reqs_max):
- self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
- num_ready_reqs = num_ready_reqs_max
+ req = self.grammar_queue[i]
+ req.grammar = req.grammar.result()
+ if req.grammar:
+ self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
+
+ for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
+ req = self.grammar_queue[i]
+ req.grammar.cancel()
+ req.grammar = None
+ error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
+ logger.error(error_msg)
+ req.finished_reason = FINISH_ABORT(
+ error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+ )
+ num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max

  self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
  self.grammar_queue = self.grammar_queue[num_ready_reqs:]

+ def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
+ if batch.next_batch_sampling_info:
+ if batch.next_batch_sampling_info.grammars is not None:
+ batch.next_batch_sampling_info.update_regex_vocab_mask()
+ self.current_stream.synchronize()
+ batch.next_batch_sampling_info.sampling_info_done.set()
+
  def watchdog_thread(self):
  """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
  self.watchdog_last_forward_ct = 0
- self.watchdog_last_time = time.time()
+ self.watchdog_last_time = time.perf_counter()

  while True:
- current = time.time()
+ current = time.perf_counter()
  if self.cur_batch is not None:
  if self.watchdog_last_forward_ct == self.forward_ct:
  if current > self.watchdog_last_time + self.watchdog_timeout:
- logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
  break
  else:
  self.watchdog_last_forward_ct = self.forward_ct
  self.watchdog_last_time = current
  time.sleep(self.watchdog_timeout // 2)

- # Print batch size and memory pool info to check whether there are de-sync issues.
- logger.error(
- f"{self.cur_batch.batch_size()=}, "
- f"{self.cur_batch.reqs=}, "
- f"{self.token_to_kv_pool_allocator.available_size()=}, "
- f"{self.tree_cache.evictable_size()=}, "
- )
- # Wait for some time so that the parent process can print the error.
+ if not disable_request_logging():
+ # Print batch size and memory pool info to check whether there are de-sync issues.
+ logger.error(
+ f"{self.cur_batch.batch_size()=}, "
+ f"{self.cur_batch.reqs=}, "
+ f"{self.token_to_kv_pool_allocator.available_size()=}, "
+ f"{self.tree_cache.evictable_size()=}, "
+ )
+
  pyspy_dump_schedulers()
+ logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
  print(file=sys.stderr, flush=True)
  print(file=sys.stdout, flush=True)
+
+ # Wait for some time so that the parent process can print the error.
  time.sleep(5)
  self.parent_process.send_signal(signal.SIGQUIT)

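The grammar queue now enforces a wall-clock budget: each pass polls the compilation future for up to 0.03 s, a timeout bumps req.grammar_wait_ct, and the request is aborted once that counter exceeds GRAMMAR_TIMEOUT / 0.03. A quick back-of-the-envelope check of the threshold, using the values that appear in this diff (the local variable names below are only for illustration):

    poll_interval = 0.03        # seconds per grammar.result(timeout=...) call
    grammar_timeout = 300.0     # default value of SGLANG_GRAMMAR_TIMEOUT
    max_waits = grammar_timeout / poll_interval
    print(max_waits)            # 10000.0 timed-out polls (~300 s) before FINISH_ABORT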
@@ -1923,25 +1974,30 @@ class Scheduler(
  )

  def abort_request(self, recv_req: AbortReq):
+ # TODO(lmzheng): abort the requests in the grammar queue.
+
  # Delete requests in the waiting queue
  to_del = []
  for i, req in enumerate(self.waiting_queue):
  if req.rid.startswith(recv_req.rid):
  to_del.append(i)
- break

  # Sort in reverse order to avoid index issues when deleting
- for i in sorted(to_del, reverse=True):
+ for i in reversed(to_del):
  req = self.waiting_queue.pop(i)
+ self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
  logger.debug(f"Abort queued request. {req.rid=}")
- return

  # Delete requests in the running batch
- for req in self.running_batch.reqs:
+ if self.cur_batch is self.running_batch or self.cur_batch is None:
+ reqs = self.running_batch.reqs
+ else:
+ reqs = self.running_batch.reqs + self.cur_batch.reqs
+
+ for req in reqs:
  if req.rid.startswith(recv_req.rid) and not req.finished():
  logger.debug(f"Abort running request. {req.rid=}")
  req.to_abort = True
- return

  def _pause_engine(self) -> Tuple[List[Req], int]:
  raise NotImplementedError()
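abort_request now matches every request whose rid starts with the given prefix instead of stopping at the first hit, so several indices can be deleted from waiting_queue in one pass. Popping them back-to-front keeps the remaining indices valid; a minimal standalone illustration of that list mechanic (plain Python, not sglang code):

    waiting_queue = ["req-a", "req-b1", "req-c", "req-b2"]
    to_del = [1, 3]                      # indices collected in ascending order

    for i in reversed(to_del):           # pop from the back so earlier indices stay valid
        removed = waiting_queue.pop(i)
        print("aborted", removed)

    assert waiting_queue == ["req-a", "req-c"]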
@@ -2090,7 +2146,10 @@ class Scheduler(

  def stop_profile(self) -> None:
  if self.profiler_activities is None:
- return
+ return ProfileReqOutput(
+ success=False,
+ message="Profiling is not in progress. Call /start_profile first.",
+ )

  logger.info("Stop profiling...")
  if self.torch_profiler is not None:
@@ -2121,18 +2180,15 @@ class Scheduler(
  self.torch_profiler_output_dir = None
  self.profiler_activities = None

- if self.profiler_target_forward_ct:
- self.send_to_tokenizer.send_pyobj(
- ProfileReqOutput(success=True, message="Succeeded.")
- )
+ return ProfileReqOutput(success=True, message="Succeeded")

  def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
  if recv_req == ExpertDistributionReq.START_RECORD:
- expert_distribution_recorder.start_record()
+ get_global_expert_distribution_recorder().start_record()
  elif recv_req == ExpertDistributionReq.STOP_RECORD:
- expert_distribution_recorder.stop_record()
+ get_global_expert_distribution_recorder().stop_record()
  elif recv_req == ExpertDistributionReq.DUMP_RECORD:
- expert_distribution_recorder.dump_record()
+ get_global_expert_distribution_recorder().dump_record()
  else:
  raise ValueError("Unrecognized ExpertDistributionReq value")
  return ExpertDistributionReqOutput()
@@ -2162,14 +2218,21 @@ class Scheduler(

  def get_print_prefix(self):
  prefix = ""
- if self.dp_rank is not None:
- prefix += f" DP{self.dp_rank}"
+ if self.attn_dp_rank is not None:
+ prefix += f" DP{self.attn_dp_rank}"
  if self.server_args.tp_size > 1:
  prefix += f" TP{self.tp_rank}"
  if self.pp_size > 1:
  prefix += f" PP{self.pp_rank}"
  return prefix

+ def _publish_kv_events(self):
+ if self.enable_kv_cache_events:
+ events = self.tree_cache.take_events()
+ if events:
+ batch = KVEventBatch(ts=time.time(), events=events)
+ self.kv_event_publisher.publish(batch)
+

  def is_health_check_generate_req(recv_req):
  return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
@@ -2225,6 +2288,10 @@ def run_scheduler_process(
  if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
  set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

+ embedding_cache_size = 100
+ if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
+ embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
+ init_embedding_cache(embedding_cache_size * 1024 * 1024)
  # Create a scheduler and run the event loop
  try:
  scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
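The scheduler process now sizes the new multimodal embedding cache (init_embedding_cache comes from sglang.srt.managers.mm_utils; a new sglang/srt/mem_cache/multimodal_cache.py also appears in this release) from SGLANG_VLM_CACHE_SIZE_MB, defaulting to 100 MB. A small sketch of how the setting resolves to a byte budget before the Scheduler is constructed; the helper name below is hypothetical and only restates the logic of this hunk.

    import os

    def resolve_vlm_cache_bytes() -> int:
        """Hypothetical helper mirroring the logic in run_scheduler_process."""
        size_mb = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))  # default: 100 MB
        return size_mb * 1024 * 1024                                      # MB -> bytes

    os.environ["SGLANG_VLM_CACHE_SIZE_MB"] = "256"   # example override for a VLM-heavy deployment
    assert resolve_vlm_cache_bytes() == 256 * 1024 * 1024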