sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168):
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
@@ -129,10 +129,10 @@ from sglang.srt.managers.session_controller import Session
129
129
  from sglang.srt.managers.tp_worker import TpModelWorker
130
130
  from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
131
131
  from sglang.srt.managers.utils import validate_input_length
132
- from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
133
132
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
134
133
  from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
135
134
  from sglang.srt.mem_cache.radix_cache import RadixCache
135
+ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
136
136
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
137
137
  from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
138
138
  from sglang.srt.reasoning_parser import ReasoningParser
@@ -252,6 +252,9 @@ class Scheduler(
252
252
  self.enable_overlap = not server_args.disable_overlap_schedule
253
253
  self.skip_tokenizer_init = server_args.skip_tokenizer_init
254
254
  self.enable_metrics = server_args.enable_metrics
255
+ self.enable_metrics_for_all_schedulers = (
256
+ server_args.enable_metrics_for_all_schedulers
257
+ )
255
258
  self.enable_kv_cache_events = server_args.kv_events_config is not None
256
259
  self.stream_interval = server_args.stream_interval
257
260
  self.spec_algorithm = SpeculativeAlgorithm.from_string(
@@ -259,6 +262,7 @@ class Scheduler(
259
262
  )
260
263
  self.gpu_id = gpu_id
261
264
  self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
265
+ self.enable_hicache_storage = server_args.hicache_storage_backend is not None
262
266
  self.page_size = server_args.page_size
263
267
  self.dp_size = server_args.dp_size
264
268
  self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
@@ -281,9 +285,6 @@ class Scheduler(
281
285
  self.send_to_tokenizer = get_zmq_socket(
282
286
  context, zmq.PUSH, port_args.tokenizer_ipc_name, False
283
287
  )
284
- self.send_metrics_from_scheduler = get_zmq_socket(
285
- context, zmq.PUSH, port_args.metrics_ipc_name, False
286
- )
287
288
 
288
289
  if server_args.skip_tokenizer_init:
289
290
  # Directly send to the TokenizerManager
@@ -309,10 +310,14 @@ class Scheduler(
309
310
  else:
310
311
  self.recv_from_tokenizer = None
311
312
  self.recv_from_rpc = None
312
- self.send_metrics_from_scheduler = None
313
313
  self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
314
314
  self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
315
315
 
316
+ if self.current_scheduler_metrics_enabled():
317
+ self.send_metrics_from_scheduler = get_zmq_socket(
318
+ context, zmq.PUSH, port_args.metrics_ipc_name, False
319
+ )
320
+
316
321
  # Init tokenizer
317
322
  self.init_tokenizer()
318
323
 
@@ -390,6 +395,14 @@ class Scheduler(
390
395
  global_server_args_dict.update(worker_global_server_args_dict)
391
396
  set_random_seed(self.random_seed)
392
397
 
398
+ # Hybrid
399
+ self.is_hybrid = self.tp_worker.is_hybrid
400
+ if self.is_hybrid:
401
+ self.sliding_window_size = self.tp_worker.sliding_window_size
402
+ self.full_tokens_per_layer, self.swa_tokens_per_layer = (
403
+ self.tp_worker.get_tokens_per_layer_info()
404
+ )
405
+
393
406
  # Print debug info
394
407
  if tp_rank == 0:
395
408
  avail_mem = get_available_gpu_memory(
@@ -487,7 +500,7 @@ class Scheduler(
487
500
  self.init_profier()
488
501
 
489
502
  # Init metrics stats
490
- self.init_metrics()
503
+ self.init_metrics(tp_rank, pp_rank, dp_rank)
491
504
  self.init_kv_events(server_args.kv_events_config)
492
505
 
493
506
  # Init request dispatcher
@@ -529,6 +542,9 @@ class Scheduler(
529
542
  if get_bool_env_var("SGLANG_GC_LOG"):
530
543
  configure_gc_logger()
531
544
 
545
+ def current_scheduler_metrics_enabled(self):
546
+ return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
547
+
532
548
  def maybe_sleep_on_idle(self):
533
549
  if self.idle_sleeper is not None:
534
550
  self.idle_sleeper.maybe_sleep()
@@ -570,7 +586,7 @@ class Scheduler(
570
586
  server_args.chunked_prefill_size is not None
571
587
  and server_args.disable_radix_cache
572
588
  ):
573
- if self.model_config.is_hybrid:
589
+ if self.is_hybrid:
574
590
  ChunkCacheClass = SWAChunkCache
575
591
  else:
576
592
  ChunkCacheClass = ChunkCache
@@ -599,10 +615,22 @@ class Scheduler(
599
615
  == "fa3" # hot fix for incompatibility
600
616
  else server_args.hicache_io_backend
601
617
  ),
618
+ hicache_storage_backend=server_args.hicache_storage_backend,
602
619
  )
603
620
  self.tp_worker.register_hicache_layer_transfer_counter(
604
621
  self.tree_cache.cache_controller.layer_done_counter
605
622
  )
623
+ elif self.is_hybrid:
624
+ assert (
625
+ self.server_args.disaggregation_mode == "null"
626
+ ), "Hybrid mode does not support disaggregation yet"
627
+ self.tree_cache = SWARadixCache(
628
+ req_to_token_pool=self.req_to_token_pool,
629
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
630
+ sliding_window_size=self.sliding_window_size,
631
+ page_size=self.page_size,
632
+ disable=server_args.disable_radix_cache,
633
+ )
606
634
 
607
635
  else:
608
636
  self.tree_cache = RadixCache(
@@ -641,7 +669,7 @@ class Scheduler(
641
669
  self.profile_in_progress: bool = False
642
670
  self.rpd_profiler = None
643
671
 
644
- def init_metrics(self):
672
+ def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
645
673
  self.last_gen_throughput: float = 0.0
646
674
  self.last_input_throughput: float = 0.0
647
675
  self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
@@ -649,15 +677,19 @@ class Scheduler(
649
677
  self.spec_num_total_forward_ct = 0
650
678
  self.cum_spec_accept_length = 0
651
679
  self.cum_spec_accept_count = 0
680
+ self.total_retracted_reqs = 0
652
681
  self.stats = SchedulerStats()
653
682
  if self.enable_metrics:
654
683
  engine_type = "unified"
655
- self.metrics_collector = SchedulerMetricsCollector(
656
- labels={
657
- "model_name": self.server_args.served_model_name,
658
- "engine_type": engine_type,
659
- },
660
- )
684
+ labels = {
685
+ "model_name": self.server_args.served_model_name,
686
+ "engine_type": engine_type,
687
+ "tp_rank": tp_rank,
688
+ "pp_rank": pp_rank,
689
+ }
690
+ if dp_rank is not None:
691
+ labels["dp_rank"] = dp_rank
692
+ self.metrics_collector = SchedulerMetricsCollector(labels=labels)
661
693
 
662
694
  def init_kv_events(self, kv_events_config: Optional[str]):
663
695
  if self.enable_kv_cache_events:
@@ -774,6 +806,7 @@ class Scheduler(
774
806
  else:
775
807
  # When the server is idle, do self-check and re-init some states
776
808
  self.check_memory()
809
+ self.check_tree_cache()
777
810
  self.new_token_ratio = self.init_new_token_ratio
778
811
  self.maybe_sleep_on_idle()
779
812
 
@@ -819,6 +852,7 @@ class Scheduler(
819
852
  elif batch is None:
820
853
  # When the server is idle, do self-check and re-init some states
821
854
  self.check_memory()
855
+ self.check_tree_cache()
822
856
  self.new_token_ratio = self.init_new_token_ratio
823
857
  self.maybe_sleep_on_idle()
824
858
 
@@ -955,6 +989,7 @@ class Scheduler(
955
989
  # When the server is idle, self-check and re-init some states
956
990
  if server_is_idle:
957
991
  self.check_memory()
992
+ self.check_tree_cache()
958
993
  self.new_token_ratio = self.init_new_token_ratio
959
994
  self.maybe_sleep_on_idle()
960
995
 
@@ -1220,6 +1255,15 @@ class Scheduler(
1220
1255
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
1221
1256
  self.disagg_decode_prealloc_queue.add(req)
1222
1257
  else:
1258
+ if self.enable_hicache_storage:
1259
+ req.init_next_round_input(self.tree_cache)
1260
+ last_hash = req.last_host_node.get_last_hash_value()
1261
+ matched_len = len(req.prefix_indices) + req.host_hit_length
1262
+ if (matched_len > 0 and last_hash is not None) or matched_len == 0:
1263
+ new_input_tokens = req.fill_ids[matched_len:]
1264
+ self.tree_cache.prefetch_from_storage(
1265
+ req.rid, req.last_host_node, new_input_tokens, last_hash
1266
+ )
1223
1267
  self.waiting_queue.append(req)
1224
1268
 
1225
1269
  def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
@@ -1306,9 +1350,26 @@ class Scheduler(
1306
1350
  self.last_input_throughput = self.last_prefill_tokens / gap_latency
1307
1351
  self.last_prefill_tokens = adder.log_input_tokens
1308
1352
 
1309
- usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
1310
- self.tree_cache.evictable_size()
1311
- )
1353
+ if self.is_hybrid:
1354
+ (
1355
+ full_num_used,
1356
+ swa_num_used,
1357
+ full_token_usage,
1358
+ swa_token_usage,
1359
+ _,
1360
+ _,
1361
+ _,
1362
+ _,
1363
+ ) = self._get_swa_token_info()
1364
+ num_used = max(full_num_used, swa_num_used)
1365
+ token_usage = max(full_token_usage, swa_token_usage)
1366
+ token_msg = (
1367
+ f"full token usage: {full_token_usage:.2f}, "
1368
+ f"swa token usage: {swa_token_usage:.2f}, "
1369
+ )
1370
+ else:
1371
+ num_used, token_usage, _, _ = self._get_token_info()
1372
+ token_msg = f"token usage: {token_usage:.2f}, "
1312
1373
 
1313
1374
  num_new_seq = len(can_run_list)
1314
1375
  f = (
@@ -1316,7 +1377,7 @@ class Scheduler(
1316
1377
  f"#new-seq: {num_new_seq}, "
1317
1378
  f"#new-token: {adder.log_input_tokens}, "
1318
1379
  f"#cached-token: {adder.log_hit_tokens}, "
1319
- f"{usage_msg}"
1380
+ f"{token_msg}"
1320
1381
  )
1321
1382
 
1322
1383
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -1328,8 +1389,6 @@ class Scheduler(
1328
1389
  f += f"#running-req: {running_bs}, "
1329
1390
  f += f"#queue-req: {len(self.waiting_queue)}, "
1330
1391
 
1331
- f += f"timestamp: {datetime.datetime.now().isoformat()}"
1332
-
1333
1392
  logger.info(f)
1334
1393
 
1335
1394
  if self.enable_metrics:
@@ -1338,7 +1397,7 @@ class Scheduler(
1338
1397
  )
1339
1398
  self.stats.num_running_reqs = running_bs
1340
1399
  self.stats.num_used_tokens = num_used
1341
- self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
1400
+ self.stats.token_usage = round(token_usage, 2)
1342
1401
  self.stats.num_queue_reqs = len(self.waiting_queue)
1343
1402
  self.stats.cache_hit_rate = cache_hit_rate
1344
1403
 
@@ -1361,16 +1420,35 @@ class Scheduler(
1361
1420
  self.last_gen_throughput = self.num_generated_tokens / gap_latency
1362
1421
  self.num_generated_tokens = 0
1363
1422
  num_running_reqs = len(batch.reqs)
1364
- usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
1365
- self.tree_cache.evictable_size()
1366
- )
1423
+ if self.is_hybrid:
1424
+ (
1425
+ full_num_used,
1426
+ swa_num_used,
1427
+ full_token_usage,
1428
+ swa_token_usage,
1429
+ _,
1430
+ _,
1431
+ _,
1432
+ _,
1433
+ ) = self._get_swa_token_info()
1434
+ num_used = max(full_num_used, swa_num_used)
1435
+ token_usage = max(full_token_usage, swa_token_usage)
1436
+ token_msg = (
1437
+ f"#full token: {full_num_used}, "
1438
+ f"full token usage: {full_token_usage:.2f}, "
1439
+ f"#swa token: {swa_num_used}, "
1440
+ f"swa token usage: {swa_token_usage:.2f}, "
1441
+ )
1442
+ else:
1443
+ num_used, token_usage, _, _ = self._get_token_info()
1444
+ token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
1367
1445
 
1368
1446
  if RECORD_STEP_TIME:
1369
1447
  self.step_time_dict[num_running_reqs].append(
1370
1448
  gap_latency / self.server_args.decode_log_interval
1371
1449
  )
1372
1450
 
1373
- msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"
1451
+ msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
1374
1452
 
1375
1453
  if self.spec_algorithm.is_none():
1376
1454
  spec_accept_length = 0
@@ -1391,42 +1469,52 @@ class Scheduler(
1391
1469
  f"cuda graph: {can_run_cuda_graph}, "
1392
1470
  f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
1393
1471
  f"#queue-req: {len(self.waiting_queue)}, "
1394
- f"timestamp: {datetime.datetime.now().isoformat()}"
1395
1472
  )
1396
1473
 
1397
1474
  logger.info(msg)
1398
1475
  if self.enable_metrics:
1399
1476
  self.stats.num_running_reqs = num_running_reqs
1400
1477
  self.stats.num_used_tokens = num_used
1401
- self.stats.token_usage = num_used / self.max_total_num_tokens
1478
+ self.stats.token_usage = round(token_usage, 2)
1402
1479
  self.stats.cache_hit_rate = 0.0
1403
1480
  self.stats.gen_throughput = self.last_gen_throughput
1404
1481
  self.stats.num_queue_reqs = len(self.waiting_queue)
1405
1482
  self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
1406
1483
  self.stats.spec_accept_length = spec_accept_length
1484
+ self.stats.total_retracted_reqs = self.total_retracted_reqs
1407
1485
  self.metrics_collector.log_stats(self.stats)
1408
1486
  self._emit_kv_metrics()
1409
1487
  self._publish_kv_events()
1410
1488
 
1411
1489
  def check_memory(self):
1412
- if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
1413
- available_token_size = self.token_to_kv_pool_allocator.full_available_size()
1490
+ if self.is_hybrid:
1491
+ (
1492
+ full_num_used,
1493
+ swa_num_used,
1494
+ _,
1495
+ _,
1496
+ full_available_size,
1497
+ full_evictable_size,
1498
+ swa_available_size,
1499
+ swa_evictable_size,
1500
+ ) = self._get_swa_token_info()
1501
+ memory_leak = full_num_used != 0 or swa_num_used != 0
1502
+ token_msg = (
1503
+ f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
1504
+ f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
1505
+ )
1414
1506
  else:
1415
- available_token_size = self.token_to_kv_pool_allocator.available_size()
1416
- available_size = available_token_size + self.tree_cache.evictable_size()
1417
- protected_size = self.tree_cache.protected_size()
1418
- memory_leak = available_size != (
1419
- self.max_total_num_tokens
1420
- if not self.enable_hierarchical_cache
1421
- else self.max_total_num_tokens - protected_size
1422
- )
1423
- if memory_leak:
1424
- msg = (
1425
- "token_to_kv_pool_allocator memory leak detected! "
1426
- f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
1427
- f"{available_token_size=}\n"
1428
- f"{self.tree_cache.evictable_size()=}\n"
1507
+ _, _, available_size, evictable_size = self._get_token_info()
1508
+ protected_size = self.tree_cache.protected_size()
1509
+ memory_leak = (available_size + evictable_size) != (
1510
+ self.max_total_num_tokens
1511
+ if not self.enable_hierarchical_cache
1512
+ else self.max_total_num_tokens - protected_size
1429
1513
  )
1514
+ token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
1515
+
1516
+ if memory_leak:
1517
+ msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
1430
1518
  raise ValueError(msg)
1431
1519
 
1432
1520
  if self.disaggregation_mode == DisaggregationMode.DECODE:
@@ -1446,24 +1534,70 @@ class Scheduler(
1446
1534
 
1447
1535
  if (
1448
1536
  self.enable_metrics
1449
- and self.attn_tp_rank == 0
1537
+ and self.current_scheduler_metrics_enabled()
1450
1538
  and time.perf_counter() > self.metrics_collector.last_log_time + 30
1451
1539
  ):
1452
1540
  # During idle time, also collect metrics every 30 seconds.
1453
- num_used = self.max_total_num_tokens - (
1454
- self.token_to_kv_pool_allocator.available_size()
1455
- + self.tree_cache.evictable_size()
1456
- )
1541
+ if self.is_hybrid:
1542
+ (
1543
+ full_num_used,
1544
+ swa_num_used,
1545
+ full_token_usage,
1546
+ swa_token_usage,
1547
+ _,
1548
+ _,
1549
+ _,
1550
+ _,
1551
+ ) = self._get_swa_token_info()
1552
+ num_used = max(full_num_used, swa_num_used)
1553
+ token_usage = max(full_token_usage, swa_token_usage)
1554
+ else:
1555
+ num_used, token_usage, _, _ = self._get_token_info()
1457
1556
  num_running_reqs = len(self.running_batch.reqs)
1458
1557
  self.stats.num_running_reqs = num_running_reqs
1459
1558
  self.stats.num_used_tokens = num_used
1460
- self.stats.token_usage = num_used / self.max_total_num_tokens
1559
+ self.stats.token_usage = round(token_usage, 2)
1461
1560
  self.stats.gen_throughput = 0
1462
1561
  self.stats.num_queue_reqs = len(self.waiting_queue)
1463
1562
  self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
1464
1563
  self.metrics_collector.log_stats(self.stats)
1465
1564
  self._publish_kv_events()
1466
1565
 
1566
+ def check_tree_cache(self):
1567
+ if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
1568
+ self.tree_cache.sanity_check()
1569
+
1570
+ def _get_token_info(self):
1571
+ available_size = self.token_to_kv_pool_allocator.available_size()
1572
+ evictable_size = self.tree_cache.evictable_size()
1573
+ num_used = self.max_total_num_tokens - (available_size + evictable_size)
1574
+ token_usage = num_used / self.max_total_num_tokens
1575
+ return num_used, token_usage, available_size, evictable_size
1576
+
1577
+ def _get_swa_token_info(self):
1578
+ full_available_size = self.token_to_kv_pool_allocator.full_available_size()
1579
+ full_evictable_size = self.tree_cache.full_evictable_size()
1580
+ swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
1581
+ swa_evictable_size = self.tree_cache.swa_evictable_size()
1582
+ full_num_used = self.full_tokens_per_layer - (
1583
+ full_available_size + full_evictable_size
1584
+ )
1585
+ swa_num_used = self.swa_tokens_per_layer - (
1586
+ swa_available_size + swa_evictable_size
1587
+ )
1588
+ full_token_usage = full_num_used / self.full_tokens_per_layer
1589
+ swa_token_usage = swa_num_used / self.swa_tokens_per_layer
1590
+ return (
1591
+ full_num_used,
1592
+ swa_num_used,
1593
+ full_token_usage,
1594
+ swa_token_usage,
1595
+ full_available_size,
1596
+ full_evictable_size,
1597
+ swa_available_size,
1598
+ swa_evictable_size,
1599
+ )
1600
+
1467
1601
  def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
1468
1602
  # Merge the prefill batch into the running batch
1469
1603
  chunked_req_to_exclude = set()
@@ -1600,6 +1734,9 @@ class Scheduler(
1600
1734
  self.running_batch.batch_is_full = True
1601
1735
  break
1602
1736
 
1737
+ if self.enable_hicache_storage:
1738
+ self.tree_cache.check_prefetch_progress(req.rid)
1739
+
1603
1740
  req.init_next_round_input(self.tree_cache)
1604
1741
  res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
1605
1742
 
@@ -1636,7 +1773,7 @@ class Scheduler(
1636
1773
  self.chunked_req.is_chunked += 1
1637
1774
 
1638
1775
  # Print stats
1639
- if self.attn_tp_rank == 0:
1776
+ if self.current_scheduler_metrics_enabled():
1640
1777
  self.log_prefill_stats(adder, can_run_list, running_bs)
1641
1778
 
1642
1779
  # Create a new batch
@@ -1695,14 +1832,17 @@ class Scheduler(
1695
1832
  old_ratio = self.new_token_ratio
1696
1833
 
1697
1834
  retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
1835
+ num_retracted_reqs = len(retracted_reqs)
1698
1836
  self.new_token_ratio = new_token_ratio
1699
1837
 
1700
1838
  logger.info(
1701
1839
  "KV cache pool is full. Retract requests. "
1702
- f"#retracted_reqs: {len(retracted_reqs)}, "
1840
+ f"#retracted_reqs: {num_retracted_reqs}, "
1703
1841
  f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
1704
1842
  )
1843
+
1705
1844
  self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
1845
+ self.total_retracted_reqs += num_retracted_reqs
1706
1846
  else:
1707
1847
  self.new_token_ratio = max(
1708
1848
  self.new_token_ratio - self.new_token_ratio_decay,
@@ -1826,7 +1966,7 @@ class Scheduler(
1826
1966
  local_batch,
1827
1967
  dp_size=self.server_args.dp_size,
1828
1968
  attn_tp_size=self.attn_tp_size,
1829
- tp_cpu_group=self.tp_cpu_group,
1969
+ tp_group=self.tp_group,
1830
1970
  get_idle_batch=self.get_idle_batch,
1831
1971
  disable_cuda_graph=self.server_args.disable_cuda_graph,
1832
1972
  spec_algorithm=self.spec_algorithm,
@@ -1835,6 +1975,7 @@ class Scheduler(
1835
1975
  enable_deepep_moe=self.server_args.enable_deepep_moe,
1836
1976
  deepep_mode=DeepEPMode[self.server_args.deepep_mode],
1837
1977
  require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
1978
+ disable_overlap_schedule=self.server_args.disable_overlap_schedule,
1838
1979
  )
1839
1980
 
1840
1981
  @staticmethod
@@ -1842,7 +1983,7 @@ class Scheduler(
1842
1983
  local_batch: ScheduleBatch,
1843
1984
  dp_size,
1844
1985
  attn_tp_size: int,
1845
- tp_cpu_group,
1986
+ tp_group,
1846
1987
  get_idle_batch,
1847
1988
  disable_cuda_graph: bool,
1848
1989
  spec_algorithm,
@@ -1851,6 +1992,7 @@ class Scheduler(
1851
1992
  enable_deepep_moe: bool,
1852
1993
  deepep_mode: DeepEPMode,
1853
1994
  require_mlp_tp_gather: bool,
1995
+ disable_overlap_schedule: bool,
1854
1996
  ):
1855
1997
  # Check if other DP workers have running batches
1856
1998
  if local_batch is None:
@@ -1881,6 +2023,12 @@ class Scheduler(
1881
2023
  )
1882
2024
 
1883
2025
  tbo_preparer = TboDPAttentionPreparer()
2026
+ if disable_overlap_schedule:
2027
+ group = tp_group.device_group
2028
+ device = tp_group.device
2029
+ else:
2030
+ group = tp_group.cpu_group
2031
+ device = "cpu"
1884
2032
 
1885
2033
  local_info = torch.tensor(
1886
2034
  [
@@ -1896,15 +2044,17 @@ class Scheduler(
1896
2044
  ),
1897
2045
  ],
1898
2046
  dtype=torch.int64,
2047
+ device=device,
1899
2048
  )
1900
2049
  global_info = torch.empty(
1901
2050
  (dp_size, attn_tp_size, 6),
1902
2051
  dtype=torch.int64,
2052
+ device=device,
1903
2053
  )
1904
2054
  torch.distributed.all_gather_into_tensor(
1905
2055
  global_info.flatten(),
1906
2056
  local_info,
1907
- group=tp_cpu_group,
2057
+ group=group,
1908
2058
  )
1909
2059
  global_num_tokens = global_info[:, 0, 0].tolist()
1910
2060
  can_cuda_graph = min(global_info[:, 0, 1].tolist())
@@ -2042,11 +2192,30 @@ class Scheduler(
2042
2192
 
2043
2193
  if not disable_request_logging():
2044
2194
  # Print batch size and memory pool info to check whether there are de-sync issues.
2195
+ if self.is_hybrid:
2196
+ (
2197
+ _,
2198
+ _,
2199
+ _,
2200
+ _,
2201
+ full_available_size,
2202
+ full_evictable_size,
2203
+ swa_available_size,
2204
+ swa_evictable_size,
2205
+ ) = self._get_swa_token_info()
2206
+ info_msg = (
2207
+ f"{full_available_size=}, "
2208
+ f"{full_evictable_size=}, "
2209
+ f"{swa_available_size=}, "
2210
+ f"{swa_evictable_size=}, "
2211
+ )
2212
+ else:
2213
+ _, _, available_size, evictable_size = self._get_token_info()
2214
+ info_msg = f"{available_size=}, " f"{evictable_size=}, "
2045
2215
  logger.error(
2046
2216
  f"{self.cur_batch.batch_size()=}, "
2047
2217
  f"{self.cur_batch.reqs=}, "
2048
- f"{self.token_to_kv_pool_allocator.available_size()=}, "
2049
- f"{self.tree_cache.evictable_size()=}, "
2218
+ f"{info_msg}"
2050
2219
  )
2051
2220
 
2052
2221
  pyspy_dump_schedulers()
@@ -2101,11 +2270,24 @@ class Scheduler(
2101
2270
 
2102
2271
  def get_load(self):
2103
2272
  # TODO(lsyin): use dynamically maintained num_waiting_tokens
2104
- load = (
2105
- self.max_total_num_tokens
2106
- - self.token_to_kv_pool_allocator.available_size()
2107
- - self.tree_cache.evictable_size()
2108
- )
2273
+ if self.is_hybrid:
2274
+ load_full = (
2275
+ self.full_tokens_per_layer
2276
+ - self.token_to_kv_pool_allocator.full_available_size()
2277
+ - self.tree_cache.full_evictable_size()
2278
+ )
2279
+ load_swa = (
2280
+ self.swa_tokens_per_layer
2281
+ - self.token_to_kv_pool_allocator.swa_available_size()
2282
+ - self.tree_cache.swa_evictable_size()
2283
+ )
2284
+ load = max(load_full, load_swa)
2285
+ else:
2286
+ load = (
2287
+ self.max_total_num_tokens
2288
+ - self.token_to_kv_pool_allocator.available_size()
2289
+ - self.tree_cache.evictable_size()
2290
+ )
2109
2291
  load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
2110
2292
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
2111
2293
  load += sum(
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
7
7
 
8
8
  from sglang.srt.disaggregation.utils import DisaggregationMode
9
9
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
10
- from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut
10
+ from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
11
11
  from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
12
12
 
13
13
  if TYPE_CHECKING:
@@ -126,7 +126,16 @@ class SchedulerOutputProcessorMixin:
126
126
  )
127
127
 
128
128
  if req.grammar is not None:
129
- req.grammar.accept_token(next_token_id)
129
+ # FIXME: this try-except block is for handling unexpected xgrammar issue.
130
+ try:
131
+ req.grammar.accept_token(next_token_id)
132
+ except ValueError as e:
133
+ # Grammar accept_token can raise ValueError if the token is not in the grammar.
134
+ # This can happen if the grammar is not set correctly or the token is invalid.
135
+ logger.error(
136
+ f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
137
+ )
138
+ self.abort_request(AbortReq(req.rid))
130
139
  req.grammar.finished = req.finished()
131
140
  else:
132
141
  # being chunked reqs' prefill is not finished
@@ -263,7 +272,16 @@ class SchedulerOutputProcessorMixin:
263
272
  )
264
273
 
265
274
  if req.grammar is not None and batch.spec_algorithm.is_none():
266
- req.grammar.accept_token(next_token_id)
275
+ # FIXME: this try-except block is for handling unexpected xgrammar issue.
276
+ try:
277
+ req.grammar.accept_token(next_token_id)
278
+ except ValueError as e:
279
+ # Grammar accept_token can raise ValueError if the token is not in the grammar.
280
+ # This can happen if the grammar is not set correctly or the token is invalid.
281
+ logger.error(
282
+ f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
283
+ )
284
+ self.abort_request(AbortReq(req.rid))
267
285
  req.grammar.finished = req.finished()
268
286
 
269
287
  self.set_next_batch_sampling_info_done(batch)
@@ -272,7 +290,7 @@ class SchedulerOutputProcessorMixin:
272
290
 
273
291
  self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
274
292
  if (
275
- self.attn_tp_rank == 0
293
+ self.current_scheduler_metrics_enabled()
276
294
  and self.forward_ct_decode % self.server_args.decode_log_interval == 0
277
295
  ):
278
296
  self.log_decode_stats(can_run_cuda_graph, running_batch=batch)