sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -38,7 +38,7 @@ import threading
38
38
  from enum import Enum, auto
39
39
  from http import HTTPStatus
40
40
  from itertools import chain
41
- from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
41
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
42
42
 
43
43
  import numpy as np
44
44
  import torch
@@ -52,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
52
52
  ScheduleBatchDisaggregationDecodeMixin,
53
53
  )
54
54
  from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
55
- from sglang.srt.layers.moe import is_tbo_enabled
56
55
  from sglang.srt.mem_cache.allocator import (
57
56
  BaseTokenToKVPoolAllocator,
58
57
  SWATokenToKVPoolAllocator,
@@ -60,7 +59,7 @@ from sglang.srt.mem_cache.allocator import (
60
59
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
61
60
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
62
61
  from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
63
- from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
62
+ from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
64
63
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
65
64
  from sglang.srt.metrics.collector import TimeStats
66
65
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
@@ -99,13 +98,13 @@ GLOBAL_SERVER_ARGS_KEYS = [
99
98
  "sampling_backend",
100
99
  "speculative_accept_threshold_single",
101
100
  "speculative_accept_threshold_acc",
101
+ "speculative_attention_mode",
102
102
  "torchao_config",
103
103
  "triton_attention_reduce_in_fp32",
104
104
  "num_reserved_decode_tokens",
105
105
  "weight_loader_disable_mmap",
106
106
  "enable_multimodal",
107
107
  "enable_symm_mem",
108
- "quantization",
109
108
  "enable_custom_logit_processor",
110
109
  "disaggregation_mode",
111
110
  ]
@@ -561,7 +560,10 @@ class Req:
561
560
  # shape: (bs, k)
562
561
  self.output_top_logprobs_val = []
563
562
  self.output_top_logprobs_idx = []
564
- self.output_token_ids_logprobs_val = []
563
+ # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
564
+ self.output_token_ids_logprobs_val: List[
565
+ Union[List[float], torch.Tensor]
566
+ ] = []
565
567
  self.output_token_ids_logprobs_idx = []
566
568
  else:
567
569
  self.output_token_logprobs_val = self.output_token_logprobs_idx = (
@@ -619,6 +621,11 @@ class Req:
619
621
  def seqlen(self):
620
622
  return len(self.origin_input_ids) + len(self.output_ids)
621
623
 
624
+ @property
625
+ def is_prefill_only(self) -> bool:
626
+ """Check if this request is prefill-only (no token generation needed)."""
627
+ return self.sampling_params.max_new_tokens == 0
628
+
622
629
  def extend_image_inputs(self, image_inputs):
623
630
  if self.multimodal_inputs is None:
624
631
  self.multimodal_inputs = image_inputs
@@ -684,9 +691,15 @@ class Req:
684
691
  self.surr_offset = max(
685
692
  self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
686
693
  )
694
+ self.surr_and_decode_ids = (
695
+ self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
696
+ )
697
+ self.cur_decode_ids_len = len(self.output_ids)
698
+ else:
699
+ self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
700
+ self.cur_decode_ids_len = len(self.output_ids)
687
701
 
688
- all_ids = self.origin_input_ids_unpadded + self.output_ids
689
- return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
702
+ return self.surr_and_decode_ids, self.read_offset - self.surr_offset
690
703
 
691
704
  def check_finished(self):
692
705
  if self.finished():
@@ -911,7 +924,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
911
924
  is_prefill_only: bool = False
912
925
 
913
926
  # hicache pointer for synchronizing data loading from CPU to GPU
914
- hicache_consumer_index: int = 0
927
+ hicache_consumer_index: int = -1
915
928
 
916
929
  @classmethod
917
930
  def init_new(
@@ -950,9 +963,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
950
963
  device=req_to_token_pool.device,
951
964
  spec_algorithm=spec_algorithm,
952
965
  return_hidden_states=any(req.return_hidden_states for req in reqs),
953
- is_prefill_only=all(
954
- req.sampling_params.max_new_tokens == 0 for req in reqs
955
- ),
966
+ is_prefill_only=all(req.is_prefill_only for req in reqs),
956
967
  chunked_req=chunked_req,
957
968
  )
958
969
 
@@ -962,8 +973,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
962
973
  def is_empty(self):
963
974
  return len(self.reqs) == 0
964
975
 
965
- def alloc_req_slots(self, num_reqs: int):
966
- req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
976
+ def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
977
+ if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
978
+ req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
979
+ else:
980
+ req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
967
981
  if req_pool_indices is None:
968
982
  raise RuntimeError(
969
983
  "alloc_req_slots runs out of memory. "
@@ -1138,7 +1152,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1138
1152
 
1139
1153
  # Allocate req slots
1140
1154
  bs = len(self.reqs)
1141
- req_pool_indices = self.alloc_req_slots(bs)
1155
+ req_pool_indices = self.alloc_req_slots(bs, self.reqs)
1142
1156
 
1143
1157
  # Init tensors
1144
1158
  reqs = self.reqs
@@ -1207,13 +1221,36 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1207
1221
  req.is_retracted = False
1208
1222
 
1209
1223
  # Compute the relative logprob_start_len in an extend batch
1224
+ #
1225
+ # Key variables:
1226
+ # - logprob_start_len: Absolute position in full sequence where logprob computation begins
1227
+ # - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
1228
+ # - extend_input_len: Number of tokens that need to be processed in this extend batch
1229
+ # (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
1230
+ # and prefix_indices are the cached/shared prefix tokens)
1231
+ #
1210
1232
  if req.logprob_start_len >= pre_len:
1211
- req.extend_logprob_start_len = min(
1212
- req.logprob_start_len - pre_len,
1213
- req.extend_input_len,
1214
- req.seqlen - 1,
1215
- )
1233
+ # Optimization for prefill-only requests: When we only need logprobs at
1234
+ # positions beyond the input sequence (to score next-token likelihood), skip all
1235
+ # input logprob computation during prefill since no generation will occur.
1236
+ if self.is_prefill_only and req.logprob_start_len == len(
1237
+ req.origin_input_ids
1238
+ ):
1239
+ # Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
1240
+ req.extend_logprob_start_len = req.extend_input_len
1241
+ else:
1242
+ # Convert absolute logprob_start_len to relative extend_logprob_start_len
1243
+ #
1244
+ # Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
1245
+ # Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
1246
+ # This means: "compute logprobs from position 3 onwards in extend batch"
1247
+ req.extend_logprob_start_len = min(
1248
+ req.logprob_start_len - pre_len,
1249
+ req.extend_input_len,
1250
+ req.seqlen - 1,
1251
+ )
1216
1252
  else:
1253
+ # logprob_start_len is before the current extend batch, so start from beginning
1217
1254
  req.extend_logprob_start_len = 0
1218
1255
 
1219
1256
  if self.return_logprob:
@@ -1372,21 +1409,28 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1372
1409
  # TODO (lianmin): Revisit this. It should be seq_len - 1
1373
1410
  self.extend_logprob_start_lens.extend([0] * running_bs)
1374
1411
 
1375
- def new_page_count_next_decode(self):
1412
+ def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
1376
1413
  page_size = self.token_to_kv_pool_allocator.page_size
1414
+ requests = (
1415
+ self.reqs
1416
+ if selected_indices is None
1417
+ else [self.reqs[i] for i in selected_indices]
1418
+ )
1377
1419
  if page_size == 1:
1378
- return len(self.reqs)
1420
+ return len(requests)
1379
1421
  # In the decoding phase, the length of a request's KV cache should be
1380
1422
  # the total length of the request minus 1
1381
1423
  return (
1382
- sum(1 for req in self.reqs if req.seqlen % page_size == 0)
1424
+ sum(1 for req in requests if req.seqlen % page_size == 0)
1383
1425
  if self.enable_overlap
1384
- else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
1426
+ else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
1385
1427
  )
1386
1428
 
1387
- def check_decode_mem(self, buf_multiplier=1):
1429
+ def check_decode_mem(
1430
+ self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
1431
+ ):
1388
1432
  num_tokens = (
1389
- self.new_page_count_next_decode()
1433
+ self.new_page_count_next_decode(selected_indices)
1390
1434
  * buf_multiplier
1391
1435
  * self.token_to_kv_pool_allocator.page_size
1392
1436
  )
@@ -1412,34 +1456,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1412
1456
  reverse=True,
1413
1457
  )
1414
1458
 
1415
- def get_required_tokens(num_reqs: int):
1416
- headroom_for_spec_decode = 0
1417
- if server_args.speculative_algorithm:
1418
- headroom_for_spec_decode += (
1419
- num_reqs
1420
- * server_args.speculative_eagle_topk
1421
- * server_args.speculative_num_steps
1422
- + num_reqs * server_args.speculative_num_draft_tokens
1423
- )
1424
- return (
1425
- num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
1426
- )
1427
-
1428
- def _get_available_size():
1429
- if self.is_hybrid:
1430
- return min(
1431
- self.token_to_kv_pool_allocator.full_available_size(),
1432
- self.token_to_kv_pool_allocator.swa_available_size(),
1433
- )
1434
- else:
1435
- return self.token_to_kv_pool_allocator.available_size()
1436
-
1437
1459
  retracted_reqs = []
1438
1460
  seq_lens_cpu = self.seq_lens.cpu().numpy()
1439
1461
  first_iter = True
1440
- while (
1441
- _get_available_size() < get_required_tokens(len(sorted_indices))
1442
- or first_iter
1462
+ while first_iter or (
1463
+ not self.check_decode_mem(selected_indices=sorted_indices)
1443
1464
  ):
1444
1465
  if len(sorted_indices) == 1:
1445
1466
  # Corner case: only one request left
@@ -1493,10 +1514,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1493
1514
  else:
1494
1515
  self.tree_cache.dec_lock_ref(req.last_node)
1495
1516
 
1496
- # NOTE(lsyin): we should use the newly evictable memory instantly.
1497
- num_tokens = len(sorted_indices) * global_config.retract_decode_steps
1498
- self._evict_tree_cache_if_needed(num_tokens)
1499
-
1500
1517
  req.reset_for_retract()
1501
1518
 
1502
1519
  if len(retracted_reqs) == 0:
@@ -1540,7 +1557,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1540
1557
  self.forward_mode = ForwardMode.DECODE
1541
1558
  bs = len(self.reqs)
1542
1559
 
1543
- if self.spec_algorithm.is_eagle():
1560
+ if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
1544
1561
  # if spec decoding is used, the decode batch is prepared inside
1545
1562
  # `forward_batch_speculative_generation` after running draft models.
1546
1563
  return
@@ -1780,6 +1797,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1780
1797
  ),
1781
1798
  extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
1782
1799
  launch_done=self.launch_done,
1800
+ is_prefill_only=self.is_prefill_only,
1783
1801
  )
1784
1802
 
1785
1803
  def copy(self):
@@ -1917,11 +1935,14 @@ class ModelWorkerBatch:
1917
1935
  spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
1918
1936
  # If set, the output of the batch contains the hidden states of the run.
1919
1937
  capture_hidden_mode: CaptureHiddenMode = None
1920
- hicache_consumer_index: int = 0
1938
+ hicache_consumer_index: int = -1
1921
1939
 
1922
1940
  # Overlap event
1923
1941
  launch_done: Optional[threading.Event] = None
1924
1942
 
1943
+ # Whether this batch is prefill-only (no token generation needed)
1944
+ is_prefill_only: bool = False
1945
+
1925
1946
 
1926
1947
  @triton.jit
1927
1948
  def write_req_to_token_pool_triton(
@@ -550,7 +550,7 @@ class PrefillAdder:
550
550
  )
551
551
  else:
552
552
  # Make sure at least one page is available
553
- trunc_len = self.rem_chunk_tokens - self.page_size + 1
553
+ trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size
554
554
  if trunc_len <= 0:
555
555
  return AddReqResult.OTHER
556
556