sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (219)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
@@ -29,6 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
  It contains low-level tensor data. Most of the data consists of GPU tensors.
  """

+ import copy
  import dataclasses
  import logging
  from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
@@ -43,14 +44,15 @@ from sglang.srt.configs.model_config import ModelConfig
  from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
- from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
  from sglang.srt.sampling.sampling_params import SamplingParams
  from sglang.srt.server_args import ServerArgs

  if TYPE_CHECKING:
- from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
+ from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

  INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -65,7 +67,11 @@ global_server_args_dict = {
  "enable_dp_attention": ServerArgs.enable_dp_attention,
  "enable_ep_moe": ServerArgs.enable_ep_moe,
  "device": ServerArgs.device,
+ "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
+ "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
  "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
+ "disable_radix_cache": ServerArgs.disable_radix_cache,
+ "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
  }

  logger = logging.getLogger(__name__)
@@ -228,12 +234,14 @@ class Req:
  sampling_params: SamplingParams,
  return_logprob: bool = False,
  top_logprobs_num: int = 0,
+ token_ids_logprob: List[int] = None,
  stream: bool = False,
  origin_input_ids_unpadded: Optional[Tuple[int]] = None,
  lora_path: Optional[str] = None,
  input_embeds: Optional[List[List[float]]] = None,
  session_id: Optional[str] = None,
  custom_logit_processor: Optional[str] = None,
+ return_hidden_states: bool = False,
  eos_token_ids: Optional[Set[int]] = None,
  ):
  # Input and output info
@@ -253,16 +261,27 @@ class Req:
  self.input_embeds = input_embeds

  # Sampling info
+ if isinstance(sampling_params.custom_params, dict):
+ sampling_params = copy.copy(sampling_params)
+ sampling_params.custom_params = sampling_params.custom_params | {
+ "__req__": self
+ }
  self.sampling_params = sampling_params
+
  self.custom_logit_processor = custom_logit_processor
+ self.return_hidden_states = return_hidden_states

  # Memory pool info
- self.req_pool_idx = None
+ self.req_pool_idx: Optional[int] = None

  # Check finish
  self.tokenizer = None
  self.finished_reason = None
+ # If we want to abort the request in the middle of the event loop, set this to true
+ # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
  self.to_abort = False
+ # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
+ self.to_abort_message: str = "Unknown error"
  self.stream = stream
  self.eos_token_ids = eos_token_ids

@@ -275,7 +294,6 @@ class Req:
  # 1: surr_offset
  # 2: read_offset
  # 3: last token
- self.vid = 0 # version id to sync decode status with in detokenizer_manager
  self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
  self.read_offset = None
  self.decoded_text = ""
@@ -284,47 +302,58 @@ class Req:
  self.image_inputs: Optional[ImageInputs] = None

  # Prefix info
+ # The indices to kv cache for the shared prefix.
  self.prefix_indices = []
- # Tokens to run prefill. input_tokens - shared_prefix_tokens.
- # Updated if chunked.
+ # Number of tokens to run prefill.
  self.extend_input_len = 0
+ # The relative logprob_start_len in an extend batch
+ self.extend_logprob_start_len = 0
  self.last_node = None

- # Chunked prefill
- self.is_being_chunked = 0
+ # Whether or not if it is chunked. It increments whenever
+ # it is chunked, and decrement whenever chunked request is
+ # processed.
+ self.is_chunked = 0

  # For retraction
  self.is_retracted = False

  # Logprobs (arguments)
  self.return_logprob = return_logprob
+ # Start index to compute logprob from.
  self.logprob_start_len = 0
  self.top_logprobs_num = top_logprobs_num
+ self.token_ids_logprob = token_ids_logprob

  # Logprobs (return values)
  self.input_token_logprobs_val: Optional[List[float]] = None
  self.input_token_logprobs_idx: Optional[List[int]] = None
  self.input_top_logprobs_val: Optional[List[float]] = None
  self.input_top_logprobs_idx: Optional[List[int]] = None
+ self.input_token_ids_logprobs_val: Optional[List[float]] = None
+ self.input_token_ids_logprobs_idx: Optional[List[int]] = None
+ # Temporary holder to store input_token_logprobs.
+ self.input_token_logprobs: Optional[List[Tuple[int]]] = None
+ self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
+ self.temp_input_top_logprobs_idx: Optional[List[int]] = None
+ self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
+ self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

  if return_logprob:
  self.output_token_logprobs_val = []
  self.output_token_logprobs_idx = []
  self.output_top_logprobs_val = []
  self.output_top_logprobs_idx = []
+ self.output_token_ids_logprobs_val = []
+ self.output_token_ids_logprobs_idx = []
  else:
  self.output_token_logprobs_val = self.output_token_logprobs_idx = (
  self.output_top_logprobs_val
- ) = self.output_top_logprobs_idx = None
+ ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
+ self.output_token_ids_logprobs_idx
+ ) = None
  self.hidden_states = []

- # Logprobs (internal values)
- # The tokens is prefilled but need to be considered as decode tokens
- # and should be updated for the decode logprobs
- self.last_update_decode_tokens = 0
- # The relative logprob_start_len in an extend batch
- self.extend_logprob_start_len = 0
-
  # Embedding (return values)
  self.embedding = None

@@ -340,6 +369,10 @@ class Req:
  self.spec_verify_ct = 0
  self.lora_path = lora_path

+ @property
+ def seqlen(self):
+ return len(self.origin_input_ids) + len(self.output_ids)
+
  def extend_image_inputs(self, image_inputs):
  if self.image_inputs is None:
  self.image_inputs = image_inputs
@@ -417,7 +450,9 @@ class Req:
  return

  if self.to_abort:
- self.finished_reason = FINISH_ABORT()
+ self.finished_reason = FINISH_ABORT(
+ message=self.to_abort_message,
+ )
  return

  if len(self.output_ids) >= self.sampling_params.max_new_tokens:
@@ -457,81 +492,22 @@ class Req:
  self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
  return

- def jump_forward_and_retokenize(self, jump_forward_str, next_state):
- if self.origin_input_text is None:
- # Recovering text can only use unpadded ids
- self.origin_input_text = self.tokenizer.decode(
- self.origin_input_ids_unpadded
- )
-
- all_text = self.origin_input_text + self.decoded_text + jump_forward_str
- all_ids = self.tokenizer.encode(all_text)
- if not all_ids:
- logger.warning("Encoded all_text resulted in empty all_ids")
- return False
-
- prompt_tokens = len(self.origin_input_ids_unpadded)
- if prompt_tokens > len(all_ids):
- logger.warning("prompt_tokens is larger than encoded all_ids")
- return False
-
- if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
- # TODO(lsyin): fix token fusion
- logger.warning(
- "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
- )
- return False
-
- old_output_ids = self.output_ids
- self.output_ids = all_ids[prompt_tokens:]
- self.decoded_text = self.decoded_text + jump_forward_str
- self.surr_offset = prompt_tokens
- self.read_offset = len(all_ids)
-
- # NOTE: A trick to reduce the surrouding tokens decoding overhead
- for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
- surr_text_ = self.tokenizer.decode(
- all_ids[self.read_offset - i : self.read_offset]
- )
- if not surr_text_.endswith("�"):
- self.surr_offset = self.read_offset - i
- break
-
- # update the inner state of the grammar
- self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
-
- if self.return_logprob:
- # For fast-forward part's logprobs
- k = 0
- for i, old_id in enumerate(old_output_ids):
- if old_id == self.output_ids[i]:
- k = k + 1
- else:
- break
- self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
- self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
- self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
- self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
- self.logprob_start_len = prompt_tokens + k
- self.last_update_decode_tokens = len(self.output_ids) - k
-
- return True
-
  def reset_for_retract(self):
  self.prefix_indices = []
  self.last_node = None
  self.extend_input_len = 0
  self.is_retracted = True
-
- # For incremental logprobs
- # TODO: Fix the `logprob_start_len`
- self.last_update_decode_tokens = 0
- self.logprob_start_len = 10**9
+ self.input_token_logprobs = None
+ self.temp_input_top_logprobs_val = None
+ self.temp_input_top_logprobs_idx = None
+ self.extend_logprob_start_len = 0
+ self.is_chunked = 0
+ self.req_pool_idx = None

  def __repr__(self):
  return (
- f"rid(n={self.rid}, "
- f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
+ f"Req(rid={self.rid}, "
+ f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
  )


@@ -545,7 +521,7 @@ class ScheduleBatch:
  # Request, memory pool, and cache
  reqs: List[Req]
  req_to_token_pool: ReqToTokenPool = None
- token_to_kv_pool: BaseTokenToKVPool = None
+ token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
  tree_cache: BasePrefixCache = None

  # Batch configs
@@ -571,11 +547,13 @@ class ScheduleBatch:

  # For DP attention
  global_num_tokens: Optional[List[int]] = None
+ global_num_tokens_for_logprob: Optional[List[int]] = None
  can_run_dp_cuda_graph: bool = False

  # For processing logprobs
  return_logprob: bool = False
  top_logprobs_nums: Optional[List[int]] = None
+ token_ids_logprobs: Optional[List[List[int]]] = None

  # For extend and mixed chunekd prefill
  prefix_lens: List[int] = None
@@ -583,6 +561,8 @@ class ScheduleBatch:
  extend_num_tokens: int = None
  decoding_reqs: List[Req] = None
  extend_logprob_start_lens: List[int] = None
+ # It comes empty list if logprob is not required.
+ extend_input_logprob_token_ids: Optional[torch.Tensor] = None

  # For encoder-decoder
  encoder_cached: Optional[List[bool]] = None
@@ -601,12 +581,12 @@ class ScheduleBatch:

  # Speculative decoding
  spec_algorithm: SpeculativeAlgorithm = None
- spec_info: Optional[SpecInfo] = None
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

  # Enable custom logit processor
  enable_custom_logit_processor: bool = False

- # Return hidden states
+ # Whether to return hidden states
  return_hidden_states: bool = False

  @classmethod
@@ -614,18 +594,17 @@ class ScheduleBatch:
  cls,
  reqs: List[Req],
  req_to_token_pool: ReqToTokenPool,
- token_to_kv_pool: ReqToTokenPool,
+ token_to_kv_pool_allocator: TokenToKVPoolAllocator,
  tree_cache: BasePrefixCache,
  model_config: ModelConfig,
  enable_overlap: bool,
  spec_algorithm: SpeculativeAlgorithm,
  enable_custom_logit_processor: bool,
- return_hidden_states: bool = False,
  ):
  return cls(
  reqs=reqs,
  req_to_token_pool=req_to_token_pool,
- token_to_kv_pool=token_to_kv_pool,
+ token_to_kv_pool_allocator=token_to_kv_pool_allocator,
  tree_cache=tree_cache,
  model_config=model_config,
  enable_overlap=enable_overlap,
@@ -635,7 +614,7 @@ class ScheduleBatch:
  device=req_to_token_pool.device,
  spec_algorithm=spec_algorithm,
  enable_custom_logit_processor=enable_custom_logit_processor,
- return_hidden_states=return_hidden_states,
+ return_hidden_states=any(req.return_hidden_states for req in reqs),
  )

  def batch_size(self):
@@ -648,25 +627,27 @@ class ScheduleBatch:
  req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
  if req_pool_indices is None:
  raise RuntimeError(
- "Out of memory. "
- "Please set a smaller number for `--max-running-requests`."
+ "alloc_req_slots runs out of memory. "
+ "Please set a smaller number for `--max-running-requests`. "
+ f"{self.req_to_token_pool.available_size()=}, "
+ f"{num_reqs=}, "
  )
  return req_pool_indices

  def alloc_token_slots(self, num_tokens: int):
- out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+ out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)

  if out_cache_loc is None:
  if self.tree_cache is not None:
- self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
- out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+ self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
+ out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)

  if out_cache_loc is None:
  phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
  logger.error(
  f"{phase_str} out of memory. Try to lower your batch size.\n"
  f"Try to allocate {num_tokens} tokens.\n"
- f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
+ f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
  )
  if self.tree_cache is not None:
  self.tree_cache.pretty_print()
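
The alloc_token_slots change above keeps the existing allocate-evict-retry flow and simply routes it through the renamed token_to_kv_pool_allocator. A minimal sketch of that pattern, using stand-in allocator and cache objects rather than sglang classes:

    def alloc_with_eviction(allocator, tree_cache, num_tokens):
        # Fast path: ask the allocator directly.
        out_cache_loc = allocator.alloc(num_tokens)
        if out_cache_loc is None and tree_cache is not None:
            # Return evictable prefix-cache slots to the allocator, then retry once.
            tree_cache.evict(num_tokens, allocator.free)
            out_cache_loc = allocator.alloc(num_tokens)
        if out_cache_loc is None:
            raise RuntimeError(f"Out of memory while allocating {num_tokens} token slots")
        return out_cache_loc
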
@@ -760,6 +741,7 @@ class ScheduleBatch:
  out_cache_loc = self.alloc_token_slots(extend_num_tokens)

  input_embeds = []
+ extend_input_logprob_token_ids = []

  pt = 0
  for i, req in enumerate(reqs):
@@ -778,22 +760,64 @@ class ScheduleBatch:
  # If req.input_embeds is already a list, append its content directly
  input_embeds.extend(req.input_embeds) # Use extend to avoid nesting

- if req.return_logprob:
- # Compute the relative logprob_start_len in an extend batch
- if req.logprob_start_len >= pre_len:
- extend_logprob_start_len = min(
- req.logprob_start_len - pre_len, req.extend_input_len - 1
- )
- else:
- raise RuntimeError(
- f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
- )
- req.extend_logprob_start_len = extend_logprob_start_len
-
  req.cached_tokens += pre_len - req.already_computed
  req.already_computed = seq_len
  req.is_retracted = False
  pre_lens.append(pre_len)
+ # Compute the relative logprob_start_len in an extend batch
+ if req.logprob_start_len >= pre_len:
+ req.extend_logprob_start_len = min(
+ req.logprob_start_len - pre_len,
+ req.extend_input_len,
+ req.seqlen - 1,
+ )
+ else:
+ req.extend_logprob_start_len = 0
+
+ if self.return_logprob:
+ # Find input logprob token ids.
+ # First, find a global index within origin_input_ids and slide it by 1
+ # to compute input logprobs. It is because you need the next token
+ # to compute input logprobs. E.g., (chunk size 2)
+ #
+ # input_logprobs = [1, 2, 3, 4]
+ # fill_ids = [1, 2]
+ # extend_input_logprob_token_id = [2, 3]
+ #
+ # Note that it can also overflow. In this case, we pad it with 0.
+ # input_logprobs = [1, 2, 3, 4]
+ # fill_ids = [3, 4]
+ # extend_input_logprob_token_id = [4, 0]
+ global_start_idx, global_end_idx = (
+ len(req.prefix_indices),
+ len(req.fill_ids),
+ )
+ # Apply logprob_start_len
+ if global_start_idx < req.logprob_start_len:
+ global_start_idx = req.logprob_start_len
+
+ logprob_token_ids = req.origin_input_ids[
+ global_start_idx + 1 : global_end_idx + 1
+ ]
+ extend_input_logprob_token_ids.extend(logprob_token_ids)
+
+ # We will need req.extend_input_len - req.extend_logprob_start_len number of
+ # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
+ extend_input_logprob_token_ids.extend(
+ [0]
+ * (
+ req.extend_input_len
+ - req.extend_logprob_start_len
+ - len(logprob_token_ids)
+ )
+ )
+
+ if self.return_logprob:
+ extend_input_logprob_token_ids = torch.tensor(
+ extend_input_logprob_token_ids
+ )
+ else:
+ extend_input_logprob_token_ids = None

  # Set fields
  self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
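
The comment block in the hunk above describes how extend_input_logprob_token_ids is built: the slice of prompt tokens covered by this chunk is shifted right by one (the label for position i is token i+1) and zero-padded out to the extend length. A standalone sketch of that selection on plain lists, with toy values rather than sglang code:

    def input_logprob_token_ids(origin_input_ids, prefix_len, fill_len,
                                logprob_start_len, extend_logprob_start_len):
        # Window of this chunk within the prompt, shifted right by one.
        global_start_idx = max(prefix_len, logprob_start_len)
        global_end_idx = fill_len
        token_ids = list(origin_input_ids[global_start_idx + 1 : global_end_idx + 1])
        # Pad with 0 so the result covers extend_input_len - extend_logprob_start_len slots.
        extend_input_len = fill_len - prefix_len
        pad = extend_input_len - extend_logprob_start_len - len(token_ids)
        return token_ids + [0] * pad

    # Chunked prefill of [1, 2, 3, 4] with chunk size 2:
    print(input_logprob_token_ids([1, 2, 3, 4], 0, 2, 0, 0))  # first chunk  -> [2, 3]
    print(input_logprob_token_ids([1, 2, 3, 4], 2, 4, 0, 0))  # second chunk -> [4, 0]
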
@@ -816,10 +840,12 @@ class ScheduleBatch:
  self.seq_lens_sum = sum(seq_lens)
  if self.return_logprob:
  self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+ self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
  self.extend_num_tokens = extend_num_tokens
  self.prefix_lens = [len(r.prefix_indices) for r in reqs]
  self.extend_lens = [r.extend_input_len for r in reqs]
  self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
+ self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

  # Write to req_to_token_pool
  pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
@@ -855,7 +881,6 @@ class ScheduleBatch:
  self.sampling_info = SamplingBatchInfo.from_schedule_batch(
  self,
  self.model_config.vocab_size,
- enable_overlap_schedule=self.enable_overlap,
  )

  def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -890,41 +915,60 @@ class ScheduleBatch:

  def check_decode_mem(self, buf_multiplier=1):
  bs = len(self.reqs) * buf_multiplier
- if self.token_to_kv_pool.available_size() >= bs:
+ if self.token_to_kv_pool_allocator.available_size() >= bs:
  return True

- self.tree_cache.evict(bs, self.token_to_kv_pool.free)
+ self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)

- if self.token_to_kv_pool.available_size() >= bs:
+ if self.token_to_kv_pool_allocator.available_size() >= bs:
  return True

  return False

- def retract_decode(self):
+ def retract_decode(self, server_args: ServerArgs):
  """Retract the decoding requests when there is not enough memory."""
  sorted_indices = [i for i in range(len(self.reqs))]

  # TODO(lsyin): improve retraction policy for radix cache
- sorted_indices.sort(
- key=lambda i: (
- len(self.reqs[i].output_ids),
- -len(self.reqs[i].origin_input_ids),
- ),
- reverse=True,
- )
+ # For spec decoding, filter_batch API can only filter
+ # requests from the back, so we can only retract from the back.
+ # TODO(sang): Clean up finish path and support better retract
+ # policy.
+ if not server_args.speculative_algorithm:
+ sorted_indices.sort(
+ key=lambda i: (
+ len(self.reqs[i].output_ids),
+ -len(self.reqs[i].origin_input_ids),
+ ),
+ reverse=True,
+ )

  retracted_reqs = []
  seq_lens_cpu = self.seq_lens.cpu().numpy()
  first_iter = True
+
+ def get_required_tokens(num_reqs: int):
+ headroom_for_spec_decode = 0
+ if server_args.speculative_algorithm:
+ headroom_for_spec_decode += (
+ num_reqs
+ * server_args.speculative_eagle_topk
+ * server_args.speculative_num_steps
+ + num_reqs * server_args.speculative_num_draft_tokens
+ )
+ return (
+ num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
+ )
+
  while (
- self.token_to_kv_pool.available_size()
- < len(sorted_indices) * global_config.retract_decode_steps
+ self.token_to_kv_pool_allocator.available_size()
+ < get_required_tokens(len(sorted_indices))
  or first_iter
  ):
  if len(sorted_indices) == 1:
  # Corner case: only one request left
  assert (
- self.token_to_kv_pool.available_size() > 0
+ self.token_to_kv_pool_allocator.available_size() > 0
  ), "No space left for only one request"
  break

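The new get_required_tokens helper above sizes the retraction target: a fixed per-request decode budget plus, when speculative decoding is enabled, headroom for the draft tree and the tokens sent to verification. A toy restatement with illustrative numbers (the constants below are assumptions, not sglang defaults):

    RETRACT_DECODE_STEPS = 20  # stand-in for global_config.retract_decode_steps

    def required_tokens(num_reqs, speculative=False,
                        eagle_topk=4, num_steps=5, num_draft_tokens=8):
        headroom = 0
        if speculative:
            # Slots for the draft tree plus the draft tokens passed to verification.
            headroom = num_reqs * eagle_topk * num_steps + num_reqs * num_draft_tokens
        return num_reqs * RETRACT_DECODE_STEPS + headroom

    print(required_tokens(16))                    # 16 * 20 = 320
    print(required_tokens(16, speculative=True))  # 320 + 16*4*5 + 16*8 = 768
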
@@ -938,7 +982,7 @@ class ScheduleBatch:
  token_indices = self.req_to_token_pool.req_to_token[
  req.req_pool_idx, : seq_lens_cpu[idx]
  ]
- self.token_to_kv_pool.free(token_indices)
+ self.token_to_kv_pool_allocator.free(token_indices)
  self.req_to_token_pool.free(req.req_pool_idx)
  del self.tree_cache.entries[req.rid]
  else:
@@ -947,7 +991,7 @@ class ScheduleBatch:
  token_indices = self.req_to_token_pool.req_to_token[
  req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
  ]
- self.token_to_kv_pool.free(token_indices)
+ self.token_to_kv_pool_allocator.free(token_indices)
  self.req_to_token_pool.free(req.req_pool_idx)

  # release the last node
@@ -956,10 +1000,13 @@ class ScheduleBatch:
  # NOTE(lsyin): we should use the newly evictable memory instantly.
  residual_size = (
  len(sorted_indices) * global_config.retract_decode_steps
- - self.token_to_kv_pool.available_size()
+ - self.token_to_kv_pool_allocator.available_size()
  )
  residual_size = max(0, residual_size)
- self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
+ self.tree_cache.evict(
+ residual_size, self.token_to_kv_pool_allocator.free
+ )
+
  req.reset_for_retract()

  self.filter_batch(keep_indices=sorted_indices)
@@ -975,59 +1022,6 @@ class ScheduleBatch:

  return retracted_reqs, new_estimate_ratio

- def check_for_jump_forward(self, pad_input_ids_func):
- jump_forward_reqs = []
- keep_indices = set(i for i in range(len(self.reqs)))
-
- for i, req in enumerate(self.reqs):
- if req.grammar is not None:
- jump_helper = req.grammar.try_jump_forward(req.tokenizer)
- if jump_helper:
- suffix_ids, _ = jump_helper
-
- # Current ids, for cache and revert
- cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
- cur_output_ids = req.output_ids
-
- req.output_ids.extend(suffix_ids)
- decode_res, new_text = req.get_next_inc_detokenization()
- if not decode_res:
- req.output_ids = cur_output_ids
- continue
-
- (
- jump_forward_str,
- next_state,
- ) = req.grammar.jump_forward_str_state(jump_helper)
-
- # Make the incrementally decoded text part of jump_forward_str
- # so that the UTF-8 will not corrupt
- jump_forward_str = new_text + jump_forward_str
- if not req.jump_forward_and_retokenize(
- jump_forward_str, next_state
- ):
- req.output_ids = cur_output_ids
- continue
-
- # The decode status has diverged from detokenizer_manager
- req.vid += 1
-
- # insert the old request into tree_cache
- self.tree_cache.cache_finished_req(req, cur_all_ids)
-
- # re-applying image padding
- if req.image_inputs is not None:
- req.origin_input_ids = pad_input_ids_func(
- req.origin_input_ids_unpadded, req.image_inputs
- )
-
- jump_forward_reqs.append(req)
- keep_indices.remove(i)
-
- self.filter_batch(keep_indices=list(keep_indices))
-
- return jump_forward_reqs
-
  def prepare_encoder_info_decode(self):
  # Reset the encoder cached status
  self.encoder_cached = [True] * len(self.reqs)
@@ -1043,17 +1037,40 @@ class ScheduleBatch:
  self.sampling_info = SamplingBatchInfo.from_schedule_batch(
  self,
  self.model_config.vocab_size,
- enable_overlap_schedule=self.enable_overlap,
  )

  def prepare_for_decode(self):
  self.forward_mode = ForwardMode.DECODE
  if self.spec_algorithm.is_eagle():
+ # if spec decoding is used, the decode batch is prepared inside
+ # `forward_batch_speculative_generation` after running draft models.
  return

+ if self.sampling_info.penalizer_orchestrator.is_required:
+ if self.enable_overlap:
+ # TODO: this can be slow, optimize this.
+ delayed_output_ids = torch.tensor(
+ [
+ (
+ req.output_ids[-1]
+ if len(req.output_ids)
+ else req.origin_input_ids[-1]
+ )
+ for req in self.reqs
+ ],
+ dtype=torch.int64,
+ device=self.device,
+ )
+ self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+ delayed_output_ids
+ )
+ else:
+ self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+ self.output_ids.to(torch.int64)
+ )
+
  self.input_ids = self.output_ids
  self.output_ids = None
- self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

  # Alloc mem
  bs = len(self.reqs)
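
In the hunk above, when overlap scheduling is on, the penalizers are updated from each request's last known token id, falling back to the last prompt token when nothing has been decoded yet. A toy sketch of that selection, using a stand-in _Req class rather than sglang's Req:

    import torch

    class _Req:
        def __init__(self, origin_input_ids, output_ids):
            self.origin_input_ids = origin_input_ids
            self.output_ids = output_ids

    reqs = [_Req([1, 2, 3], [7, 8]), _Req([4, 5], [])]
    delayed_output_ids = torch.tensor(
        [r.output_ids[-1] if len(r.output_ids) else r.origin_input_ids[-1] for r in reqs],
        dtype=torch.int64,
    )
    print(delayed_output_ids)  # tensor([8, 5])
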
@@ -1081,14 +1098,15 @@ class ScheduleBatch:

  def filter_batch(
  self,
- being_chunked_req: Optional[Req] = None,
+ chunked_req_to_exclude: Optional[Req] = None,
  keep_indices: Optional[List[int]] = None,
  ):
  if keep_indices is None:
  keep_indices = [
  i
  for i in range(len(self.reqs))
- if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
+ if not self.reqs[i].finished()
+ and self.reqs[i] is not chunked_req_to_exclude
  ]

  if keep_indices is None or len(keep_indices) == 0:
@@ -1100,31 +1118,34 @@ class ScheduleBatch:
  # No need to filter
  return

+ keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
+ self.device, non_blocking=True
+ )
+
  if self.model_config.is_encoder_decoder:
- self.encoder_lens = self.encoder_lens[keep_indices]
+ self.encoder_lens = self.encoder_lens[keep_indices_device]
  self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

  self.reqs = [self.reqs[i] for i in keep_indices]
- new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
- self.device, non_blocking=True
- )
- self.req_pool_indices = self.req_pool_indices[new_indices]
- self.seq_lens = self.seq_lens[new_indices]
+ self.req_pool_indices = self.req_pool_indices[keep_indices_device]
+ self.seq_lens = self.seq_lens[keep_indices_device]
  self.out_cache_loc = None
  self.seq_lens_sum = self.seq_lens.sum().item()
- self.output_ids = self.output_ids[new_indices]
+ self.output_ids = self.output_ids[keep_indices_device]
  self.return_logprob = any(req.return_logprob for req in self.reqs)
  if self.return_logprob:
  self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
+ self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
  else:
  self.top_logprobs_nums = None
+ self.token_ids_logprobs = None

  self.has_stream = any(req.stream for req in self.reqs)
  self.has_grammar = any(req.grammar for req in self.reqs)

- self.sampling_info.filter_batch(keep_indices, new_indices)
+ self.sampling_info.filter_batch(keep_indices, keep_indices_device)
  if self.spec_info:
- self.spec_info.filter_batch(new_indices)
+ self.spec_info.filter_batch(keep_indices_device)

  def merge_batch(self, other: "ScheduleBatch"):
  # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -1147,23 +1168,32 @@ class ScheduleBatch:
  self.output_ids = torch.concat([self.output_ids, other.output_ids])
  if self.return_logprob and other.return_logprob:
  self.top_logprobs_nums.extend(other.top_logprobs_nums)
+ self.token_ids_logprobs.extend(other.token_ids_logprobs)
  elif self.return_logprob:
  self.top_logprobs_nums.extend([0] * len(other.reqs))
+ self.token_ids_logprobs.extend([None] * len(other.reqs))
  elif other.return_logprob:
  self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
+ self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
  self.reqs.extend(other.reqs)

  self.return_logprob |= other.return_logprob
  self.has_stream |= other.has_stream
  self.has_grammar |= other.has_grammar
+ self.return_hidden_states |= other.return_hidden_states

  if self.spec_info:
  self.spec_info.merge_batch(other.spec_info)

- def get_model_worker_batch(self):
+ def get_model_worker_batch(self) -> ModelWorkerBatch:
  if self.forward_mode.is_decode_or_idle():
+ if global_server_args_dict["enable_flashinfer_mla"]:
+ decode_seq_lens = self.seq_lens.cpu()
+ else:
+ decode_seq_lens = None
  extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
  else:
+ decode_seq_lens = None
  extend_seq_lens = self.extend_lens
  extend_prefix_lens = self.prefix_lens
  extend_logprob_start_lens = self.extend_logprob_start_lens
@@ -1186,8 +1216,11 @@ class ScheduleBatch:
  seq_lens_sum=self.seq_lens_sum,
  return_logprob=self.return_logprob,
  top_logprobs_nums=self.top_logprobs_nums,
+ token_ids_logprobs=self.token_ids_logprobs,
  global_num_tokens=self.global_num_tokens,
+ global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
  can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
+ decode_seq_lens=decode_seq_lens,
  extend_num_tokens=self.extend_num_tokens,
  extend_seq_lens=extend_seq_lens,
  extend_prefix_lens=extend_prefix_lens,
@@ -1213,6 +1246,7 @@ class ScheduleBatch:
  else CaptureHiddenMode.NULL
  )
  ),
+ extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
  )

  def copy(self):
@@ -1247,7 +1281,7 @@ class ModelWorkerBatch:
  req_pool_indices: torch.Tensor
  # The sequence length
  seq_lens: torch.Tensor
- # The indices of output tokens in the token_to_kv_pool
+ # The indices of output tokens in the token_to_kv_pool_allocator
  out_cache_loc: torch.Tensor

  # The sum of all sequence lengths
@@ -1256,16 +1290,22 @@ class ModelWorkerBatch:
  # For logprob
  return_logprob: bool
  top_logprobs_nums: Optional[List[int]]
+ token_ids_logprobs: Optional[List[List[int]]]

  # For DP attention
  global_num_tokens: Optional[List[int]]
+ global_num_tokens_for_logprob: Optional[List[int]]
  can_run_dp_cuda_graph: bool

+ # For decode
+ decode_seq_lens: Optional[torch.Tensor]
+
  # For extend
  extend_num_tokens: Optional[int]
  extend_seq_lens: Optional[List[int]]
  extend_prefix_lens: Optional[List[int]]
  extend_logprob_start_lens: Optional[List[int]]
+ extend_input_logprob_token_ids: Optional[torch.Tensor]

  # For multimodal
  image_inputs: Optional[List[ImageInputs]]
@@ -1287,7 +1327,8 @@ class ModelWorkerBatch:

  # Speculative decoding
  spec_algorithm: SpeculativeAlgorithm = None
- spec_info: Optional[SpecInfo] = None
+ spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
+ # If set, the output of the batch contains the hidden states of the run.
  capture_hidden_mode: CaptureHiddenMode = None