sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +302 -414
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +13 -8
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +144 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +773 -334
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +225 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +68 -37
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +102 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +56 -31
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +280 -81
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +135 -60
  181. sglang/srt/speculative/build_eagle_tree.py +8 -9
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
  183. sglang/srt/speculative/eagle_utils.py +92 -57
  184. sglang/srt/speculative/eagle_worker.py +238 -111
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -29,6 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 It contains low-level tensor data. Most of the data consists of GPU tensors.
 """
 
+import copy
 import dataclasses
 import logging
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
@@ -43,14 +44,15 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 
 if TYPE_CHECKING:
-    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
@@ -65,8 +67,11 @@ global_server_args_dict = {
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "device": ServerArgs.device,
+    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
+    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
     "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
+    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
 }
 
 logger = logging.getLogger(__name__)
@@ -229,12 +234,14 @@ class Req:
         sampling_params: SamplingParams,
         return_logprob: bool = False,
         top_logprobs_num: int = 0,
+        token_ids_logprob: List[int] = None,
         stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
         custom_logit_processor: Optional[str] = None,
+        return_hidden_states: bool = False,
         eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
@@ -254,16 +261,27 @@ class Req:
         self.input_embeds = input_embeds
 
         # Sampling info
+        if isinstance(sampling_params.custom_params, dict):
+            sampling_params = copy.copy(sampling_params)
+            sampling_params.custom_params = sampling_params.custom_params | {
+                "__req__": self
+            }
         self.sampling_params = sampling_params
+
         self.custom_logit_processor = custom_logit_processor
+        self.return_hidden_states = return_hidden_states
 
         # Memory pool info
-        self.req_pool_idx = None
+        self.req_pool_idx: Optional[int] = None
 
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        # If we want to abort the request in the middle of the event loop, set this to true
+        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
         self.to_abort = False
+        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
+        self.to_abort_message: str = "Unknown error"
         self.stream = stream
         self.eos_token_ids = eos_token_ids
 
@@ -276,7 +294,6 @@ class Req:
         # 1: surr_offset
         # 2: read_offset
         # 3: last token
-        self.vid = 0  # version id to sync decode status with in detokenizer_manager
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
         self.decoded_text = ""
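The custom_params hunk above shallow-copies the incoming SamplingParams and merges a back-reference to the request under the "__req__" key, so the caller's object and its dict are never mutated. A minimal, self-contained sketch of that copy-and-union pattern (the SamplingParams stand-in below is a simplified dataclass, not sglang's real class):

```python
# Illustrative sketch only; mirrors the copy.copy + dict-union pattern from the hunk above.
import copy
from dataclasses import dataclass
from typing import Optional

@dataclass
class SamplingParams:  # simplified stand-in, not sglang's class
    temperature: float = 1.0
    custom_params: Optional[dict] = None

req = object()  # stands in for the Req instance
shared = SamplingParams(custom_params={"bias": 2.0})

if isinstance(shared.custom_params, dict):
    per_req = copy.copy(shared)  # shallow copy: the caller's object is not touched
    # `|` builds a new dict, so the caller's custom_params also stays unchanged
    per_req.custom_params = shared.custom_params | {"__req__": req}

assert "__req__" not in shared.custom_params
assert per_req.custom_params["__req__"] is req
```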
@@ -285,47 +302,58 @@ class Req:
         self.image_inputs: Optional[ImageInputs] = None
 
         # Prefix info
+        # The indices to kv cache for the shared prefix.
         self.prefix_indices = []
-        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
-        # Updated if chunked.
+        # Number of tokens to run prefill.
         self.extend_input_len = 0
+        # The relative logprob_start_len in an extend batch
+        self.extend_logprob_start_len = 0
         self.last_node = None
 
-        # Chunked prefill
-        self.is_being_chunked = 0
+        # Whether or not if it is chunked. It increments whenever
+        # it is chunked, and decrement whenever chunked request is
+        # processed.
+        self.is_chunked = 0
 
         # For retraction
         self.is_retracted = False
 
         # Logprobs (arguments)
         self.return_logprob = return_logprob
+        # Start index to compute logprob from.
         self.logprob_start_len = 0
         self.top_logprobs_num = top_logprobs_num
+        self.token_ids_logprob = token_ids_logprob
 
         # Logprobs (return values)
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
         self.input_top_logprobs_val: Optional[List[float]] = None
         self.input_top_logprobs_idx: Optional[List[int]] = None
+        self.input_token_ids_logprobs_val: Optional[List[float]] = None
+        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
+        # Temporary holder to store input_token_logprobs.
+        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
+        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
+        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
+        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
+        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None
 
         if return_logprob:
             self.output_token_logprobs_val = []
             self.output_token_logprobs_idx = []
             self.output_top_logprobs_val = []
             self.output_top_logprobs_idx = []
+            self.output_token_ids_logprobs_val = []
+            self.output_token_ids_logprobs_idx = []
         else:
             self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                 self.output_top_logprobs_val
-            ) = self.output_top_logprobs_idx = None
+            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
+                self.output_token_ids_logprobs_idx
+            ) = None
         self.hidden_states = []
 
-        # Logprobs (internal values)
-        # The tokens is prefilled but need to be considered as decode tokens
-        # and should be updated for the decode logprobs
-        self.last_update_decode_tokens = 0
-        # The relative logprob_start_len in an extend batch
-        self.extend_logprob_start_len = 0
-
         # Embedding (return values)
         self.embedding = None
 
@@ -341,6 +369,10 @@ class Req:
         self.spec_verify_ct = 0
         self.lora_path = lora_path
 
+    @property
+    def seqlen(self):
+        return len(self.origin_input_ids) + len(self.output_ids)
+
     def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
             self.image_inputs = image_inputs
@@ -418,7 +450,9 @@ class Req:
             return
 
         if self.to_abort:
-            self.finished_reason = FINISH_ABORT()
+            self.finished_reason = FINISH_ABORT(
+                message=self.to_abort_message,
+            )
             return
 
         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
@@ -458,81 +492,22 @@ class Req:
                     self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                     return
 
-    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        if self.origin_input_text is None:
-            # Recovering text can only use unpadded ids
-            self.origin_input_text = self.tokenizer.decode(
-                self.origin_input_ids_unpadded
-            )
-
-        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
-        all_ids = self.tokenizer.encode(all_text)
-        if not all_ids:
-            logger.warning("Encoded all_text resulted in empty all_ids")
-            return False
-
-        prompt_tokens = len(self.origin_input_ids_unpadded)
-        if prompt_tokens > len(all_ids):
-            logger.warning("prompt_tokens is larger than encoded all_ids")
-            return False
-
-        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
-            # TODO(lsyin): fix token fusion
-            logger.warning(
-                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
-            )
-            return False
-
-        old_output_ids = self.output_ids
-        self.output_ids = all_ids[prompt_tokens:]
-        self.decoded_text = self.decoded_text + jump_forward_str
-        self.surr_offset = prompt_tokens
-        self.read_offset = len(all_ids)
-
-        # NOTE: A trick to reduce the surrouding tokens decoding overhead
-        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
-            surr_text_ = self.tokenizer.decode(
-                all_ids[self.read_offset - i : self.read_offset]
-            )
-            if not surr_text_.endswith("�"):
-                self.surr_offset = self.read_offset - i
-                break
-
-        # update the inner state of the grammar
-        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
-
-        if self.return_logprob:
-            # For fast-forward part's logprobs
-            k = 0
-            for i, old_id in enumerate(old_output_ids):
-                if old_id == self.output_ids[i]:
-                    k = k + 1
-                else:
-                    break
-            self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
-            self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
-            self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
-            self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
-            self.logprob_start_len = prompt_tokens + k
-            self.last_update_decode_tokens = len(self.output_ids) - k
-
-        return True
-
     def reset_for_retract(self):
         self.prefix_indices = []
         self.last_node = None
         self.extend_input_len = 0
         self.is_retracted = True
-
-        # For incremental logprobs
-        # TODO: Fix the `logprob_start_len`
-        self.last_update_decode_tokens = 0
-        self.logprob_start_len = 10**9
+        self.input_token_logprobs = None
+        self.temp_input_top_logprobs_val = None
+        self.temp_input_top_logprobs_idx = None
+        self.extend_logprob_start_len = 0
+        self.is_chunked = 0
+        self.req_pool_idx = None
 
     def __repr__(self):
         return (
-            f"rid(n={self.rid}, "
-            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
+            f"Req(rid={self.rid}, "
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
         )
 
 
@@ -546,7 +521,7 @@ class ScheduleBatch:
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool = None
-    token_to_kv_pool: BaseTokenToKVPool = None
+    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
     tree_cache: BasePrefixCache = None
 
     # Batch configs
@@ -572,11 +547,13 @@ class ScheduleBatch:
 
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
+    global_num_tokens_for_logprob: Optional[List[int]] = None
     can_run_dp_cuda_graph: bool = False
 
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
+    token_ids_logprobs: Optional[List[List[int]]] = None
 
     # For extend and mixed chunekd prefill
     prefix_lens: List[int] = None
@@ -584,6 +561,8 @@ class ScheduleBatch:
     extend_num_tokens: int = None
     decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
+    # It comes empty list if logprob is not required.
+    extend_input_logprob_token_ids: Optional[torch.Tensor] = None
 
     # For encoder-decoder
     encoder_cached: Optional[List[bool]] = None
@@ -602,12 +581,12 @@ class ScheduleBatch:
 
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[SpecInfo] = None
+    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
 
     # Enable custom logit processor
    enable_custom_logit_processor: bool = False
 
-    # Return hidden states
+    # Whether to return hidden states
     return_hidden_states: bool = False
 
     @classmethod
@@ -615,18 +594,17 @@ class ScheduleBatch:
         cls,
         reqs: List[Req],
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
-        return_hidden_states: bool = False,
     ):
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
-            token_to_kv_pool=token_to_kv_pool,
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
             tree_cache=tree_cache,
             model_config=model_config,
             enable_overlap=enable_overlap,
@@ -636,7 +614,7 @@ class ScheduleBatch:
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
-            return_hidden_states=return_hidden_states,
+            return_hidden_states=any(req.return_hidden_states for req in reqs),
         )
 
     def batch_size(self):
@@ -649,25 +627,27 @@ class ScheduleBatch:
         req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
-                "Out of memory. "
-                "Please set a smaller number for `--max-running-requests`."
+                "alloc_req_slots runs out of memory. "
+                "Please set a smaller number for `--max-running-requests`. "
+                f"{self.req_to_token_pool.available_size()=}, "
+                f"{num_reqs=}, "
             )
         return req_pool_indices
 
     def alloc_token_slots(self, num_tokens: int):
-        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
 
         if out_cache_loc is None:
             if self.tree_cache is not None:
-                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
-                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+                self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
+                out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
 
             if out_cache_loc is None:
                 phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                 logger.error(
                     f"{phase_str} out of memory. Try to lower your batch size.\n"
                     f"Try to allocate {num_tokens} tokens.\n"
-                    f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
+                    f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                 )
                 if self.tree_cache is not None:
                     self.tree_cache.pretty_print()
@@ -761,6 +741,7 @@ class ScheduleBatch:
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
         input_embeds = []
+        extend_input_logprob_token_ids = []
 
         pt = 0
         for i, req in enumerate(reqs):
@@ -779,22 +760,64 @@ class ScheduleBatch:
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
-            if req.return_logprob:
-                # Compute the relative logprob_start_len in an extend batch
-                if req.logprob_start_len >= pre_len:
-                    extend_logprob_start_len = min(
-                        req.logprob_start_len - pre_len, req.extend_input_len - 1
-                    )
-                else:
-                    raise RuntimeError(
-                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
-                    )
-                req.extend_logprob_start_len = extend_logprob_start_len
-
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
             pre_lens.append(pre_len)
+            # Compute the relative logprob_start_len in an extend batch
+            if req.logprob_start_len >= pre_len:
+                req.extend_logprob_start_len = min(
+                    req.logprob_start_len - pre_len,
+                    req.extend_input_len,
+                    req.seqlen - 1,
+                )
+            else:
+                req.extend_logprob_start_len = 0
+
+            if self.return_logprob:
+                # Find input logprob token ids.
+                # First, find a global index within origin_input_ids and slide it by 1
+                # to compute input logprobs. It is because you need the next token
+                # to compute input logprobs. E.g., (chunk size 2)
+                #
+                # input_logprobs = [1, 2, 3, 4]
+                # fill_ids = [1, 2]
+                # extend_input_logprob_token_id = [2, 3]
+                #
+                # Note that it can also overflow. In this case, we pad it with 0.
+                # input_logprobs = [1, 2, 3, 4]
+                # fill_ids = [3, 4]
+                # extend_input_logprob_token_id = [4, 0]
+                global_start_idx, global_end_idx = (
+                    len(req.prefix_indices),
+                    len(req.fill_ids),
+                )
+                # Apply logprob_start_len
+                if global_start_idx < req.logprob_start_len:
+                    global_start_idx = req.logprob_start_len
+
+                logprob_token_ids = req.origin_input_ids[
+                    global_start_idx + 1 : global_end_idx + 1
+                ]
+                extend_input_logprob_token_ids.extend(logprob_token_ids)
+
+                # We will need req.extend_input_len - req.extend_logprob_start_len number of
+                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
+                extend_input_logprob_token_ids.extend(
+                    [0]
+                    * (
+                        req.extend_input_len
+                        - req.extend_logprob_start_len
+                        - len(logprob_token_ids)
+                    )
+                )
+
+        if self.return_logprob:
+            extend_input_logprob_token_ids = torch.tensor(
+                extend_input_logprob_token_ids
+            )
+        else:
+            extend_input_logprob_token_ids = None
 
         # Set fields
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
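The new return_logprob branch above collects, per request, the token ids that input logprobs will be scored against: it takes a window over origin_input_ids shifted right by one (the logprob at a position is computed against the next token) and zero-pads when the shifted window runs past the end of the prompt. A standalone rerun of the worked example from the code comment, using the same hypothetical values rather than real request state:

```python
# Illustrative sketch of the slide-by-one + zero-padding logic from the hunk above.
origin_input_ids = [1, 2, 3, 4]   # the "input_logprobs" example from the comment
prefix_indices = [0, 1]           # two tokens already cached, so this chunk fills [3, 4]
fill_ids = [1, 2, 3, 4]           # prefix + the tokens filled in this extend step
logprob_start_len = 0
extend_input_len = len(fill_ids) - len(prefix_indices)
extend_logprob_start_len = 0

global_start_idx = max(len(prefix_indices), logprob_start_len)
global_end_idx = len(fill_ids)

# Shift by one: the logprob at position i is evaluated against token i + 1.
logprob_token_ids = origin_input_ids[global_start_idx + 1 : global_end_idx + 1]

# Pad with 0 when the shifted window runs past the end of the prompt.
needed = extend_input_len - extend_logprob_start_len
logprob_token_ids += [0] * (needed - len(logprob_token_ids))

print(logprob_token_ids)  # [4, 0], matching the overflow example in the comment
```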
@@ -817,10 +840,12 @@ class ScheduleBatch:
         self.seq_lens_sum = sum(seq_lens)
         if self.return_logprob:
             self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
         self.extend_num_tokens = extend_num_tokens
         self.prefix_lens = [len(r.prefix_indices) for r in reqs]
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
+        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
 
         # Write to req_to_token_pool
         pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
@@ -856,7 +881,6 @@ class ScheduleBatch:
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            enable_overlap_schedule=self.enable_overlap,
         )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -891,41 +915,60 @@ class ScheduleBatch:
 
     def check_decode_mem(self, buf_multiplier=1):
         bs = len(self.reqs) * buf_multiplier
-        if self.token_to_kv_pool.available_size() >= bs:
+        if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True
 
-        self.tree_cache.evict(bs, self.token_to_kv_pool.free)
+        self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)
 
-        if self.token_to_kv_pool.available_size() >= bs:
+        if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True
 
         return False
 
-    def retract_decode(self):
+    def retract_decode(self, server_args: ServerArgs):
         """Retract the decoding requests when there is not enough memory."""
         sorted_indices = [i for i in range(len(self.reqs))]
 
         # TODO(lsyin): improve retraction policy for radix cache
-        sorted_indices.sort(
-            key=lambda i: (
-                len(self.reqs[i].output_ids),
-                -len(self.reqs[i].origin_input_ids),
-            ),
-            reverse=True,
-        )
+        # For spec decoding, filter_batch API can only filter
+        # requests from the back, so we can only retract from the back.
+        # TODO(sang): Clean up finish path and support better retract
+        # policy.
+        if not server_args.speculative_algorithm:
+            sorted_indices.sort(
+                key=lambda i: (
+                    len(self.reqs[i].output_ids),
+                    -len(self.reqs[i].origin_input_ids),
+                ),
+                reverse=True,
+            )
 
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
         first_iter = True
+
+        def get_required_tokens(num_reqs: int):
+            headroom_for_spec_decode = 0
+            if server_args.speculative_algorithm:
+                headroom_for_spec_decode += (
+                    num_reqs
+                    * server_args.speculative_eagle_topk
+                    * server_args.speculative_num_steps
+                    + num_reqs * server_args.speculative_num_draft_tokens
+                )
+            return (
+                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
+            )
+
         while (
-            self.token_to_kv_pool.available_size()
-            < len(sorted_indices) * global_config.retract_decode_steps
+            self.token_to_kv_pool_allocator.available_size()
+            < get_required_tokens(len(sorted_indices))
             or first_iter
         ):
             if len(sorted_indices) == 1:
                 # Corner case: only one request left
                 assert (
-                    self.token_to_kv_pool.available_size() > 0
+                    self.token_to_kv_pool_allocator.available_size() > 0
                 ), "No space left for only one request"
                 break
 
@@ -939,7 +982,7 @@ class ScheduleBatch:
                 token_indices = self.req_to_token_pool.req_to_token[
                     req.req_pool_idx, : seq_lens_cpu[idx]
                 ]
-                self.token_to_kv_pool.free(token_indices)
+                self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
                 del self.tree_cache.entries[req.rid]
             else:
@@ -948,7 +991,7 @@ class ScheduleBatch:
                 token_indices = self.req_to_token_pool.req_to_token[
                     req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                 ]
-                self.token_to_kv_pool.free(token_indices)
+                self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
 
                 # release the last node
@@ -957,10 +1000,13 @@ class ScheduleBatch:
                 # NOTE(lsyin): we should use the newly evictable memory instantly.
                 residual_size = (
                     len(sorted_indices) * global_config.retract_decode_steps
-                    - self.token_to_kv_pool.available_size()
+                    - self.token_to_kv_pool_allocator.available_size()
                 )
                 residual_size = max(0, residual_size)
-                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
+                self.tree_cache.evict(
+                    residual_size, self.token_to_kv_pool_allocator.free
+                )
+
             req.reset_for_retract()
 
         self.filter_batch(keep_indices=sorted_indices)
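With speculative decoding enabled, the get_required_tokens helper added to retract_decode above reserves extra KV-cache headroom on top of the usual retract_decode_steps budget before it stops retracting. A quick worked example with hypothetical values for the arguments (the real values come from global_config and ServerArgs):

```python
# Illustrative only: the retraction headroom formula from the hunk above, with made-up values.
retract_decode_steps = 20          # global_config.retract_decode_steps (example value)
speculative_eagle_topk = 4         # --speculative-eagle-topk (example value)
speculative_num_steps = 5          # --speculative-num-steps (example value)
speculative_num_draft_tokens = 8   # --speculative-num-draft-tokens (example value)

def get_required_tokens(num_reqs: int, spec_enabled: bool) -> int:
    headroom = 0
    if spec_enabled:
        # Tokens drafted across the EAGLE steps plus the verify tokens, per request.
        headroom = (
            num_reqs * speculative_eagle_topk * speculative_num_steps
            + num_reqs * speculative_num_draft_tokens
        )
    return num_reqs * retract_decode_steps + headroom

print(get_required_tokens(16, spec_enabled=False))  # 16 * 20 = 320
print(get_required_tokens(16, spec_enabled=True))   # 320 + 16*4*5 + 16*8 = 768
```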
@@ -976,59 +1022,6 @@ class ScheduleBatch:
 
         return retracted_reqs, new_estimate_ratio
 
-    def check_for_jump_forward(self, pad_input_ids_func):
-        jump_forward_reqs = []
-        keep_indices = set(i for i in range(len(self.reqs)))
-
-        for i, req in enumerate(self.reqs):
-            if req.grammar is not None:
-                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
-                if jump_helper:
-                    suffix_ids, _ = jump_helper
-
-                    # Current ids, for cache and revert
-                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
-                    cur_output_ids = req.output_ids
-
-                    req.output_ids.extend(suffix_ids)
-                    decode_res, new_text = req.get_next_inc_detokenization()
-                    if not decode_res:
-                        req.output_ids = cur_output_ids
-                        continue
-
-                    (
-                        jump_forward_str,
-                        next_state,
-                    ) = req.grammar.jump_forward_str_state(jump_helper)
-
-                    # Make the incrementally decoded text part of jump_forward_str
-                    # so that the UTF-8 will not corrupt
-                    jump_forward_str = new_text + jump_forward_str
-                    if not req.jump_forward_and_retokenize(
-                        jump_forward_str, next_state
-                    ):
-                        req.output_ids = cur_output_ids
-                        continue
-
-                    # The decode status has diverged from detokenizer_manager
-                    req.vid += 1
-
-                    # insert the old request into tree_cache
-                    self.tree_cache.cache_finished_req(req, cur_all_ids)
-
-                    # re-applying image padding
-                    if req.image_inputs is not None:
-                        req.origin_input_ids = pad_input_ids_func(
-                            req.origin_input_ids_unpadded, req.image_inputs
-                        )
-
-                    jump_forward_reqs.append(req)
-                    keep_indices.remove(i)
-
-        self.filter_batch(keep_indices=list(keep_indices))
-
-        return jump_forward_reqs
-
     def prepare_encoder_info_decode(self):
         # Reset the encoder cached status
         self.encoder_cached = [True] * len(self.reqs)
@@ -1044,17 +1037,40 @@ class ScheduleBatch:
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            enable_overlap_schedule=self.enable_overlap,
         )
 
     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
         if self.spec_algorithm.is_eagle():
+            # if spec decoding is used, the decode batch is prepared inside
+            # `forward_batch_speculative_generation` after running draft models.
             return
 
+        if self.sampling_info.penalizer_orchestrator.is_required:
+            if self.enable_overlap:
+                # TODO: this can be slow, optimize this.
+                delayed_output_ids = torch.tensor(
+                    [
+                        (
+                            req.output_ids[-1]
+                            if len(req.output_ids)
+                            else req.origin_input_ids[-1]
+                        )
+                        for req in self.reqs
+                    ],
+                    dtype=torch.int64,
+                    device=self.device,
+                )
+                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    delayed_output_ids
+                )
+            else:
+                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    self.output_ids.to(torch.int64)
+                )
+
         self.input_ids = self.output_ids
         self.output_ids = None
-        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)
 
         # Alloc mem
         bs = len(self.reqs)
@@ -1082,14 +1098,15 @@ class ScheduleBatch:
 
     def filter_batch(
         self,
-        being_chunked_req: Optional[Req] = None,
+        chunked_req_to_exclude: Optional[Req] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
-                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
+                if not self.reqs[i].finished()
+                and self.reqs[i] is not chunked_req_to_exclude
             ]
 
         if keep_indices is None or len(keep_indices) == 0:
@@ -1101,31 +1118,34 @@ class ScheduleBatch:
             # No need to filter
             return
 
+        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+
         if self.model_config.is_encoder_decoder:
-            self.encoder_lens = self.encoder_lens[keep_indices]
+            self.encoder_lens = self.encoder_lens[keep_indices_device]
             self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
 
         self.reqs = [self.reqs[i] for i in keep_indices]
-        new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
-        self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.seq_lens = self.seq_lens[new_indices]
+        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
+        self.seq_lens = self.seq_lens[keep_indices_device]
         self.out_cache_loc = None
         self.seq_lens_sum = self.seq_lens.sum().item()
-        self.output_ids = self.output_ids[new_indices]
+        self.output_ids = self.output_ids[keep_indices_device]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
             self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
+            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
         else:
             self.top_logprobs_nums = None
+            self.token_ids_logprobs = None
 
         self.has_stream = any(req.stream for req in self.reqs)
         self.has_grammar = any(req.grammar for req in self.reqs)
 
-        self.sampling_info.filter_batch(keep_indices, new_indices)
+        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
         if self.spec_info:
-            self.spec_info.filter_batch(new_indices)
+            self.spec_info.filter_batch(keep_indices_device)
 
     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -1148,23 +1168,32 @@ class ScheduleBatch:
         self.output_ids = torch.concat([self.output_ids, other.output_ids])
         if self.return_logprob and other.return_logprob:
             self.top_logprobs_nums.extend(other.top_logprobs_nums)
+            self.token_ids_logprobs.extend(other.token_ids_logprobs)
         elif self.return_logprob:
             self.top_logprobs_nums.extend([0] * len(other.reqs))
+            self.token_ids_logprobs.extend([None] * len(other.reqs))
         elif other.return_logprob:
             self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
+            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
         self.reqs.extend(other.reqs)
 
         self.return_logprob |= other.return_logprob
         self.has_stream |= other.has_stream
         self.has_grammar |= other.has_grammar
+        self.return_hidden_states |= other.return_hidden_states
 
         if self.spec_info:
             self.spec_info.merge_batch(other.spec_info)
 
-    def get_model_worker_batch(self):
+    def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                decode_seq_lens = self.seq_lens.cpu()
+            else:
+                decode_seq_lens = None
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
+            decode_seq_lens = None
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
@@ -1187,8 +1216,11 @@ class ScheduleBatch:
             seq_lens_sum=self.seq_lens_sum,
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
+            token_ids_logprobs=self.token_ids_logprobs,
             global_num_tokens=self.global_num_tokens,
+            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
+            decode_seq_lens=decode_seq_lens,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1214,6 +1246,7 @@ class ScheduleBatch:
                     else CaptureHiddenMode.NULL
                 )
             ),
+            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
         )
 
     def copy(self):
@@ -1248,7 +1281,7 @@ class ModelWorkerBatch:
     req_pool_indices: torch.Tensor
     # The sequence length
     seq_lens: torch.Tensor
-    # The indices of output tokens in the token_to_kv_pool
+    # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor
 
     # The sum of all sequence lengths
@@ -1257,16 +1290,22 @@ class ModelWorkerBatch:
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
+    token_ids_logprobs: Optional[List[List[int]]]
 
     # For DP attention
     global_num_tokens: Optional[List[int]]
+    global_num_tokens_for_logprob: Optional[List[int]]
     can_run_dp_cuda_graph: bool
 
+    # For decode
+    decode_seq_lens: Optional[torch.Tensor]
+
     # For extend
     extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
     extend_prefix_lens: Optional[List[int]]
     extend_logprob_start_lens: Optional[List[int]]
+    extend_input_logprob_token_ids: Optional[torch.Tensor]
 
     # For multimodal
     image_inputs: Optional[List[ImageInputs]]
@@ -1288,7 +1327,8 @@ class ModelWorkerBatch:
 
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[SpecInfo] = None
+    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
+    # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
 