sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,6 @@ import time
  from collections import defaultdict, deque
  from concurrent import futures
  from dataclasses import dataclass
- from http import HTTPStatus
  from pathlib import Path
  from types import SimpleNamespace
  from typing import Dict, List, Optional, Tuple, Union
@@ -36,6 +35,7 @@ from torch.distributed import barrier

  from sglang.global_config import global_config
  from sglang.srt.configs.model_config import ModelConfig
+ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
  from sglang.srt.constrained.base_grammar_backend import (
  INVALID_GRAMMAR_OBJ,
  create_grammar_backend,
@@ -140,6 +140,7 @@ from sglang.srt.utils import (
  DeepEPMode,
  DynamicGradMode,
  broadcast_pyobj,
+ configure_gc_logger,
  configure_logger,
  disable_request_logging,
  get_available_gpu_memory,
@@ -148,6 +149,8 @@ from sglang.srt.utils import (
  kill_itself_when_parent_died,
  point_to_point_pyobj,
  pyspy_dump_schedulers,
+ require_mlp_sync,
+ require_mlp_tp_gather,
  set_gpu_proc_affinity,
  set_random_seed,
  suppress_other_loggers,
@@ -179,6 +182,27 @@ class EmbeddingBatchResult:
  bid: int


+ class IdleSleeper:
+ """
+ In setups which have long inactivity periods it is desirable to reduce
+ system power consumption when sglang does nothing. This would lead not only
+ to power savings, but also to more CPU thermal headroom when a request
+ eventually comes. This is important in cases when multiple GPUs are connected
+ as each GPU would otherwise pin one thread at 100% CPU usage.
+
+ The simplest solution is to use zmq.Poller on all sockets that may receive
+ data that needs handling immediately.
+ """
+
+ def __init__(self, sockets):
+ self.poller = zmq.Poller()
+ for s in sockets:
+ self.poller.register(s, zmq.POLLIN)
+
+ def maybe_sleep(self):
+ self.poller.poll(1000)
+
+
  class Scheduler(
  SchedulerOutputProcessorMixin,
  SchedulerDisaggregationDecodeMixin,
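The IdleSleeper added above replaces busy-polling with a blocking zmq.Poller wait. The following standalone sketch is not sglang code; the endpoint name and the 1-second timeout are illustrative, but it shows the same pattern in isolation:

import zmq

ctx = zmq.Context()
receiver = ctx.socket(zmq.PULL)
receiver.bind("ipc:///tmp/example_scheduler_input")  # hypothetical endpoint

poller = zmq.Poller()
poller.register(receiver, zmq.POLLIN)

while True:
    # Block for up to 1 second instead of spinning; returns early if data arrives.
    events = dict(poller.poll(1000))
    if receiver in events:
        msg = receiver.recv_pyobj()
        print("got request:", msg)  # handle the request immediately
    # If nothing arrived, the loop stayed idle without pinning a CPU core.

The per-GPU scheduler process only pays this wait when it would otherwise spin, which is why the diff gates it behind the new sleep_on_idle server argument.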
@@ -228,6 +252,8 @@ class Scheduler(

  # Init inter-process communication
  context = zmq.Context(2)
+ self.idle_sleeper = None
+
  if self.pp_rank == 0 and self.attn_tp_rank == 0:
  self.recv_from_tokenizer = get_zmq_socket(
  context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -250,6 +276,13 @@ class Scheduler(
  self.recv_from_rpc = get_zmq_socket(
  context, zmq.DEALER, port_args.rpc_ipc_name, False
  )
+ if self.server_args.sleep_on_idle:
+ self.idle_sleeper = IdleSleeper(
+ [
+ self.recv_from_tokenizer,
+ self.recv_from_rpc,
+ ]
+ )
  else:
  self.recv_from_tokenizer = None
  self.recv_from_rpc = None
@@ -361,7 +394,7 @@ class Scheduler(
  self.forward_ct = 0
  self.forward_ct_decode = 0
  self.num_generated_tokens = 0
- self.num_prefill_tokens = 0
+ self.last_prefill_tokens = 0
  self.last_decode_stats_tic = time.perf_counter()
  self.last_prefill_stats_tic = time.perf_counter()
  self.return_health_check_ct = 0
@@ -420,8 +453,6 @@ class Scheduler(
  t = threading.Thread(target=self.watchdog_thread, daemon=True)
  t.start()
  self.parent_process = psutil.Process().parent()
-
- # Init memory saver
  self.memory_saver_adapter = TorchMemorySaverAdapter.create(
  enable=server_args.enable_memory_saver
  )
@@ -478,6 +509,13 @@ class Scheduler(
  )
  self.init_disaggregation()

+ if get_bool_env_var("SGLANG_GC_LOG"):
+ configure_gc_logger()
+
+ def maybe_sleep_on_idle(self):
+ if self.idle_sleeper is not None:
+ self.idle_sleeper.maybe_sleep()
+
  def init_tokenizer(self):
  server_args = self.server_args

@@ -525,12 +563,20 @@ class Scheduler(
  self.tree_cache = HiRadixCache(
  req_to_token_pool=self.req_to_token_pool,
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
- tp_cache_group=self.tp_cpu_group,
+ tp_cache_group=(
+ self.attn_tp_cpu_group
+ if self.server_args.enable_dp_attention
+ else self.tp_cpu_group
+ ),
  page_size=self.page_size,
  hicache_ratio=server_args.hicache_ratio,
  hicache_size=server_args.hicache_size,
  hicache_write_policy=server_args.hicache_write_policy,
  )
+ self.tp_worker.register_hicache_layer_transfer_counter(
+ self.tree_cache.cache_controller.layer_done_counter
+ )
+
  else:
  self.tree_cache = RadixCache(
  req_to_token_pool=self.req_to_token_pool,
@@ -585,15 +631,21 @@ class Scheduler(
  self.disaggregation_mode == DisaggregationMode.DECODE
  ): # *2 for the headroom.
  buffer_size = (self.req_to_token_pool.size) * 2
- req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+ self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
  buffer_size
  )
- self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
+ self.disagg_metadata_buffers = MetadataBuffers(
+ buffer_size,
+ hidden_size=self.model_config.hf_text_config.hidden_size,
+ dtype=self.model_config.dtype,
+ custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
+ )

  # The decode requests polling kv cache
  self.disagg_decode_transfer_queue = DecodeTransferQueue(
  gloo_group=self.attn_tp_cpu_group,
- req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+ req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
+ tp_rank=self.tp_rank,
  metadata_buffers=self.disagg_metadata_buffers,
  scheduler=self,
  tree_cache=self.tree_cache,
@@ -608,7 +660,7 @@ class Scheduler(
  if self.draft_worker is None
  else self.draft_worker.model_runner.token_to_kv_pool
  ),
- req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+ req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
  metadata_buffers=self.disagg_metadata_buffers,
  scheduler=self,
  transfer_queue=self.disagg_decode_transfer_queue,
@@ -616,7 +668,12 @@ class Scheduler(
  gloo_group=self.attn_tp_cpu_group,
  tp_rank=self.tp_rank,
  tp_size=self.tp_size,
+ dp_size=self.server_args.dp_size,
+ gpu_id=self.gpu_id,
  bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+ max_total_num_tokens=self.max_total_num_tokens,
+ prefill_pp_size=self.server_args.disaggregation_prefill_pp,
+ num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens,
  transfer_backend=self.transfer_backend,
  )

@@ -626,10 +683,15 @@ class Scheduler(
  elif self.disaggregation_mode == DisaggregationMode.PREFILL:
  # *2 for the headroom.
  buffer_size = self.max_running_requests * 2
- req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+ self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
  buffer_size
  )
- self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
+ self.disagg_metadata_buffers = MetadataBuffers(
+ buffer_size,
+ hidden_size=self.model_config.hf_text_config.hidden_size,
+ dtype=self.model_config.dtype,
+ custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
+ )

  self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
  token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
@@ -638,14 +700,20 @@ class Scheduler(
  if self.draft_worker is None
  else self.draft_worker.model_runner.token_to_kv_pool
  ),
- req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+ req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
  metadata_buffers=self.disagg_metadata_buffers,
  tp_rank=self.tp_rank,
  tp_size=self.tp_size,
+ gpu_id=self.gpu_id,
  bootstrap_port=self.server_args.disaggregation_bootstrap_port,
  gloo_group=self.attn_tp_cpu_group,
- transfer_backend=self.transfer_backend,
+ max_total_num_tokens=self.max_total_num_tokens,
+ decode_tp_size=self.server_args.disaggregation_decode_tp,
+ decode_dp_size=self.server_args.disaggregation_decode_dp,
  scheduler=self,
+ pp_rank=self.pp_rank,
+ pp_size=self.pp_size,
+ transfer_backend=self.transfer_backend,
  )
  # The prefill requests that are in the middle of kv sending
  self.disagg_prefill_inflight_queue: List[Req] = []
@@ -667,6 +735,7 @@ class Scheduler(
  # When the server is idle, do self-check and re-init some states
  self.check_memory()
  self.new_token_ratio = self.init_new_token_ratio
+ self.maybe_sleep_on_idle()

  self.last_batch = batch

@@ -711,6 +780,7 @@ class Scheduler(
  # When the server is idle, do self-check and re-init some states
  self.check_memory()
  self.new_token_ratio = self.init_new_token_ratio
+ self.maybe_sleep_on_idle()

  self.last_batch = batch

@@ -747,11 +817,28 @@ class Scheduler(
  result.next_token_ids,
  result.bid,
  )
- pp_outputs = PPProxyTensors(
- {
- "next_token_ids": next_token_ids,
- }
- )
+ if self.cur_batch.return_logprob:
+ pp_outputs = PPProxyTensors(
+ {
+ "next_token_ids": next_token_ids,
+ "extend_input_len_per_req": result.extend_input_len_per_req,
+ "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
+ }
+ | (
+ {
+ f"logits_output.{k}": v
+ for k, v in result.logits_output.__dict__.items()
+ }
+ if result.logits_output is not None
+ else {}
+ )
+ )
+ else:
+ pp_outputs = PPProxyTensors(
+ {
+ "next_token_ids": next_token_ids,
+ }
+ )
  # send the output from the last round to let the next stage worker run post processing
  self.pp_group.send_tensor_dict(
  pp_outputs.tensors,
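The new pipeline-parallel branch above ships the logits output between stages by flattening its fields into "logits_output."-prefixed entries of the proxy tensor dict and rebuilding the object on the receiving side. A minimal sketch of that round trip, using a toy dataclass rather than the real LogitsProcessorOutput:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyLogitsOutput:
    next_token_logits: Optional[torch.Tensor] = None
    next_token_logprobs: Optional[torch.Tensor] = None


def flatten(prefix: str, obj: ToyLogitsOutput) -> dict:
    # One entry per populated field, key prefixed so it can share a dict with other outputs.
    return {f"{prefix}.{k}": v for k, v in obj.__dict__.items() if v is not None}


def unflatten(prefix: str, tensors: dict) -> Optional[ToyLogitsOutput]:
    kwargs = {k[len(prefix) + 1:]: v for k, v in tensors.items() if k.startswith(prefix + ".")}
    return ToyLogitsOutput(**kwargs) if kwargs else None


sent = {"next_token_ids": torch.tensor([1, 2, 3])}
sent |= flatten("logits_output", ToyLogitsOutput(next_token_logits=torch.randn(3, 8)))
rebuilt = unflatten("logits_output", sent)
assert rebuilt is not None and rebuilt.next_token_logits.shape == (3, 8)

Flattening into a flat tensor dict is what lets the existing send_tensor_dict path carry the extra logprob metadata without a new message type.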
@@ -768,12 +855,25 @@ class Scheduler(
  )
  )
  mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
+ logits_output_args = {
+ k[len("logits_output.") :]: v
+ for k, v in next_pp_outputs.tensors.items()
+ if k.startswith("logits_output.")
+ }
+ if len(logits_output_args) > 0:
+ logits_output = LogitsProcessorOutput(**logits_output_args)
+ else:
+ logits_output = None
  output_result = GenerationBatchResult(
- logits_output=None,
+ logits_output=logits_output,
  pp_hidden_states_proxy_tensors=None,
  next_token_ids=next_pp_outputs["next_token_ids"],
- extend_input_len_per_req=None,
- extend_logprob_start_len_per_req=None,
+ extend_input_len_per_req=next_pp_outputs.tensors.get(
+ "extend_input_len_per_req", None
+ ),
+ extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
+ "extend_logprob_start_len_per_req", None
+ ),
  bid=bids[next_mb_id],
  can_run_cuda_graph=result.can_run_cuda_graph,
  )
@@ -816,6 +916,7 @@ class Scheduler(
  if server_is_idle:
  self.check_memory()
  self.new_token_ratio = self.init_new_token_ratio
+ self.maybe_sleep_on_idle()

  def recv_requests(self) -> List[Req]:
  """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
@@ -1073,18 +1174,22 @@ class Scheduler(
  def _add_request_to_queue(self, req: Req):
  req.queue_time_start = time.perf_counter()
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
- self.disagg_prefill_bootstrap_queue.add(req)
+ self.disagg_prefill_bootstrap_queue.add(
+ req, self.model_config.num_key_value_heads
+ )
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
  self.disagg_decode_prealloc_queue.add(req)
  else:
  self.waiting_queue.append(req)

- def _extend_requests_to_queue(self, reqs: List[Req]):
+ def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
- self.disagg_prefill_bootstrap_queue.extend(reqs)
+ self.disagg_prefill_bootstrap_queue.extend(
+ reqs, self.model_config.num_key_value_heads
+ )
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
  # If this is a decode server, we put the request to the decode pending prealloc queue
- self.disagg_decode_prealloc_queue.extend(reqs)
+ self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
  else:
  self.waiting_queue.extend(reqs)

@@ -1097,6 +1202,7 @@ class Scheduler(
  recv_req.input_text,
  recv_req.input_ids,
  recv_req.sampling_params,
+ token_type_ids=recv_req.token_type_ids,
  )
  req.tokenizer = self.tokenizer

@@ -1141,8 +1247,8 @@ class Scheduler(
  ):
  gap_latency = time.perf_counter() - self.last_prefill_stats_tic
  self.last_prefill_stats_tic = time.perf_counter()
- self.last_input_throughput = self.num_prefill_tokens / gap_latency
- self.num_prefill_tokens = 0
+ self.last_input_throughput = self.last_prefill_tokens / gap_latency
+ self.last_prefill_tokens = adder.log_input_tokens

  num_used = self.max_total_num_tokens - (
  self.token_to_kv_pool_allocator.available_size()
@@ -1156,15 +1262,15 @@ class Scheduler(
  f"#new-token: {adder.log_input_tokens}, "
  f"#cached-token: {adder.log_hit_tokens}, "
  f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
- f"#running-req: {running_bs}, "
  )

  if self.disaggregation_mode == DisaggregationMode.PREFILL:
  f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
  f += f"#queue-req: {len(self.waiting_queue)}, "
  f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
- f += f"time: {gap_latency:.2f} "
+ f += f"input throughput (token/s): {self.last_input_throughput:.2f} "
  else:
+ f += f"#running-req: {running_bs}, "
  f += f"#queue-req: {len(self.waiting_queue)}"

  logger.info(f)
@@ -1227,6 +1333,7 @@ class Scheduler(

  if self.disaggregation_mode == DisaggregationMode.DECODE:
  msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+ msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

  msg += (
  f"cuda graph: {can_run_cuda_graph}, "
@@ -1267,7 +1374,14 @@ class Scheduler(
  )
  raise ValueError(msg)

- if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
+ if self.disaggregation_mode == DisaggregationMode.DECODE:
+ req_total_size = (
+ self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
+ )
+ else:
+ req_total_size = self.req_to_token_pool.size
+
+ if len(self.req_to_token_pool.free_slots) != req_total_size:
  msg = (
  "req_to_token_pool memory leak detected!"
  f"available_size={len(self.req_to_token_pool.free_slots)}, "
@@ -1328,6 +1442,15 @@ class Scheduler(
  self.running_batch.merge_batch(self.last_batch)

  new_batch = self.get_new_batch_prefill()
+
+ need_dp_attn_preparation = require_mlp_sync(self.server_args)
+
+ if need_dp_attn_preparation and not self.spec_algorithm.is_none():
+ # In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
+ # We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
+ new_batch, _ = self.prepare_mlp_sync_batch(new_batch)
+ need_dp_attn_preparation = new_batch is None
+
  if new_batch is not None:
  # Run prefill first if possible
  ret = new_batch
@@ -1340,8 +1463,8 @@ class Scheduler(
  ret = None

  # Handle DP attention
- if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
- ret, _ = self.prepare_dp_attn_batch(ret)
+ if need_dp_attn_preparation:
+ ret, _ = self.prepare_mlp_sync_batch(ret)

  return ret

@@ -1373,15 +1496,14 @@ class Scheduler(
  return None

  if self.enable_hierarchical_cache:
- # check for completion of hierarchical cache activities to release memory
- self.tree_cache.writing_check()
- self.tree_cache.loading_check()
+ self.tree_cache.check_hicache_events()

  # Get priority queue
- prefix_computed = self.policy.calc_priority(self.waiting_queue)
+ self.policy.calc_priority(self.waiting_queue)

  # Prefill policy
  adder = PrefillAdder(
+ self.page_size,
  self.tree_cache,
  self.token_to_kv_pool_allocator,
  self.running_batch,
@@ -1423,14 +1545,8 @@ class Scheduler(
  self.running_batch.batch_is_full = True
  break

- req.init_next_round_input(
- None if prefix_computed else self.tree_cache,
- self.enable_hierarchical_cache,
- )
-
- res = adder.add_one_req(
- req, self.chunked_req, self.enable_hierarchical_cache
- )
+ req.init_next_round_input(self.tree_cache)
+ res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))

  if res != AddReqResult.CONTINUE:
  if res == AddReqResult.NO_TOKEN:
@@ -1457,9 +1573,6 @@ class Scheduler(
  x for x in self.waiting_queue if x not in set(can_run_list)
  ]

- if self.enable_hierarchical_cache:
- self.tree_cache.ready_to_load_cache()
-
  if adder.new_chunked_req is not None:
  assert self.chunked_req is None
  self.chunked_req = adder.new_chunked_req
@@ -1483,6 +1596,12 @@ class Scheduler(
  self.server_args.enable_custom_logit_processor,
  chunked_req=self.chunked_req,
  )
+ if self.enable_hierarchical_cache:
+ # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
+ new_batch.hicache_consumer_index = (
+ self.tree_cache.ready_to_load_host_cache()
+ )
+
  new_batch.prepare_for_extend()

  # Mixed-style chunked prefill
@@ -1528,7 +1647,7 @@ class Scheduler(
  f"#retracted_reqs: {len(retracted_reqs)}, "
  f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
  )
- self._extend_requests_to_queue(retracted_reqs)
+ self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
  else:
  self.new_token_ratio = max(
  self.new_token_ratio - self.new_token_ratio_decay,
@@ -1558,6 +1677,11 @@ class Scheduler(
  if self.is_generation:
  if self.spec_algorithm.is_none():
  model_worker_batch = batch.get_model_worker_batch()
+
+ # update the consumer index of hicache to the running batch
+ self.tp_worker.set_hicache_consumer(
+ model_worker_batch.hicache_consumer_index
+ )
  if self.pp_group.is_last_rank:
  logits_output, next_token_ids, can_run_cuda_graph = (
  self.tp_worker.forward_batch_generation(model_worker_batch)
@@ -1586,13 +1710,15 @@ class Scheduler(
  # These 2 values are needed for processing the output, but the values can be
  # modified by overlap schedule. So we have to copy them here so that
  # we can use the correct values in output processing.
- if batch.return_logprob:
+ if batch.return_logprob or self.spec_algorithm.is_eagle():
  extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
+ else:
+ extend_input_len_per_req = None
+ if batch.return_logprob:
  extend_logprob_start_len_per_req = [
  req.extend_logprob_start_len for req in batch.reqs
  ]
  else:
- extend_input_len_per_req = None
  extend_logprob_start_len_per_req = None

  ret = GenerationBatchResult(
@@ -1640,12 +1766,11 @@ class Scheduler(
  self.return_health_check_ct -= 1
  self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

- def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
- return self.prepare_dp_attn_batch_raw(
+ def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
+ return self.prepare_mlp_sync_batch_raw(
  local_batch,
  dp_size=self.server_args.dp_size,
  attn_tp_size=self.attn_tp_size,
- moe_dense_tp_size=self.server_args.moe_dense_tp_size,
  tp_cpu_group=self.tp_cpu_group,
  get_idle_batch=self.get_idle_batch,
  disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1654,14 +1779,14 @@ class Scheduler(
  enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
  enable_deepep_moe=self.server_args.enable_deepep_moe,
  deepep_mode=DeepEPMode[self.server_args.deepep_mode],
+ require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
  )

  @staticmethod
- def prepare_dp_attn_batch_raw(
+ def prepare_mlp_sync_batch_raw(
  local_batch: ScheduleBatch,
  dp_size,
  attn_tp_size: int,
- moe_dense_tp_size: Optional[int],
  tp_cpu_group,
  get_idle_batch,
  disable_cuda_graph: bool,
@@ -1670,6 +1795,7 @@ class Scheduler(
  enable_two_batch_overlap: bool,
  enable_deepep_moe: bool,
  deepep_mode: DeepEPMode,
+ require_mlp_tp_gather: bool,
  ):
  # Check if other DP workers have running batches
  if local_batch is None:
@@ -1677,8 +1803,6 @@ class Scheduler(
  num_tokens_for_logprob = 0
  elif local_batch.forward_mode.is_decode():
  num_tokens = local_batch.batch_size()
- if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
- num_tokens = num_tokens * speculative_num_draft_tokens
  num_tokens_for_logprob = num_tokens
  else:
  num_tokens = local_batch.extend_num_tokens
@@ -1697,11 +1821,6 @@ class Scheduler(
  else:
  can_cuda_graph = 0

- if not spec_algorithm.is_none():
- # TODO(sang): Support cuda graph when idle batch is there.
- if local_batch is None or local_batch.forward_mode.is_idle():
- can_cuda_graph = 0
-
  is_extend_in_batch = (
  local_batch.forward_mode.is_extend() if local_batch else False
  )
@@ -1746,7 +1865,7 @@ class Scheduler(

  if local_batch is not None:
  # TODO: handle the case when moe_dense_tp_size != 1
- if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
+ if not require_mlp_tp_gather:
  local_batch.global_num_tokens = [num_tokens]
  local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
  else:
@@ -1754,6 +1873,7 @@ class Scheduler(
  local_batch.global_num_tokens_for_logprob = (
  global_num_tokens_for_logprob
  )
+ local_batch.is_extend_in_batch = any(is_extend_in_batch)
  local_batch.tbo_split_seq_index = tbo_split_seq_index
  local_batch.global_forward_mode = global_forward_mode

@@ -1761,6 +1881,7 @@ class Scheduler(
  if not disable_cuda_graph:
  local_batch.can_run_dp_cuda_graph = can_cuda_graph

+ # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
  return local_batch, any(is_extend_in_batch)

  def get_idle_batch(self):
@@ -2055,7 +2176,8 @@ class Scheduler(
  # In this case, we change the input_ids to be only one token to make this prefill cheap.
  if req.rid.startswith(recv_req.rid):
  logger.debug(f"Abort grammar queue request. {req.rid=}")
- req.grammar.cancel()
+ if req.grammar:
+ req.grammar.cancel()
  req.set_finish_with_abort("Aborted by AbortReq.")

  # Delete requests in the running batch
@@ -2120,23 +2242,40 @@ class Scheduler(
  return GetWeightsByNameReqOutput(parameter)

  def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
- self.memory_saver_adapter.check_validity(
- caller_name="release_memory_occupation"
- )
- self.stashed_model_static_state = _export_static_state(
- self.tp_worker.worker.model_runner.model
- )
- self.memory_saver_adapter.pause()
- self.flush_cache()
+ tags = recv_req.tags
+ import subprocess
+
+ if tags is None:
+ tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+ if GPU_MEMORY_TYPE_KV_CACHE in tags:
+ self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
+ self.flush_cache()
+
+ if GPU_MEMORY_TYPE_WEIGHTS in tags:
+ self.stashed_model_static_state = _export_static_state(
+ self.tp_worker.worker.model_runner.model
+ )
+ self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
+
  return ReleaseMemoryOccupationReqOutput()

  def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
- self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
- self.memory_saver_adapter.resume()
- _import_static_state(
- self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
- )
- del self.stashed_model_static_state
+ tags = recv_req.tags
+ if tags is None or len(tags) == 0:
+ tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+ if GPU_MEMORY_TYPE_WEIGHTS in tags:
+ self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
+ _import_static_state(
+ self.tp_worker.worker.model_runner.model,
+ self.stashed_model_static_state,
+ )
+ del self.stashed_model_static_state
+
+ if GPU_MEMORY_TYPE_KV_CACHE in tags:
+ self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
+
  return ResumeMemoryOccupationReqOutput()

  def slow_down(self, recv_req: SlowDownReqInput):
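release_memory_occupation and resume_memory_occupation now act on a list of tags, so weights and KV cache can be released or restored independently. A toy illustration of the tag-filtered dispatch; the adapter class and the tag string values here are stand-ins, not the real TorchMemorySaverAdapter or sglang constants:

GPU_MEMORY_TYPE_WEIGHTS = "weights"
GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"


class ToyMemorySaver:
    def __init__(self):
        self.paused = set()

    def pause(self, tag: str):
        self.paused.add(tag)

    def resume(self, tag: str):
        self.paused.discard(tag)


def release(saver: ToyMemorySaver, tags=None):
    # No tags means "release everything", matching the default in the diff above.
    tags = tags or [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
    if GPU_MEMORY_TYPE_KV_CACHE in tags:
        saver.pause(GPU_MEMORY_TYPE_KV_CACHE)   # the scheduler also flushes the radix cache here
    if GPU_MEMORY_TYPE_WEIGHTS in tags:
        saver.pause(GPU_MEMORY_TYPE_WEIGHTS)    # the scheduler also stashes static model state here
    return saver.paused


saver = ToyMemorySaver()
assert release(saver, tags=[GPU_MEMORY_TYPE_KV_CACHE]) == {GPU_MEMORY_TYPE_KV_CACHE}

Splitting the two memory types is what allows a caller to drop only the KV cache during long idle periods while keeping the weights resident for a fast resume.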
@@ -2365,8 +2504,10 @@ class Scheduler(
  if self.profiler_decode_ct > self.profiler_target_decode_ct:
  if self.profile_in_progress:
  self.stop_profile(stage=ForwardMode.DECODE)
+ elif batch.forward_mode.is_idle():
+ pass
  else:
- raise RuntimeError("unsupported profile stage")
+ raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
  else:
  # Check profiler
  if (