sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,7 @@
 from __future__ import annotations

+import enum
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -34,83 +36,53 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
 import copy
 import dataclasses
 import logging
-import threading
+import re
+import time
 from enum import Enum, auto
 from http import HTTPStatus
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

 import numpy as np
 import torch
-import triton
-import triton.language as tl

-from sglang.global_config import global_config
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.disaggregation.base import BaseKVSender
 from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
+from sglang.srt.environ import envs
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
-from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
-from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
+from sglang.srt.mem_cache.chunk_cache import SWAChunkCache
+from sglang.srt.mem_cache.common import (
+    alloc_for_decode,
+    alloc_for_extend,
+    evict_from_tree_cache,
+)
+from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.radix_cache import RadixKey
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.metrics.collector import TimeStats
+from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
-from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import flatten_nested_list, support_triton
+from sglang.srt.server_args import ServerArgs, get_global_server_args
+from sglang.srt.utils import flatten_nested_list

 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+    from sglang.srt.speculative.eagle_info import EagleDraftInput
+    from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

-GLOBAL_SERVER_ARGS_KEYS = [
-    "attention_backend",
-    "mm_attention_backend",
-    "debug_tensor_dump_inject",
-    "debug_tensor_dump_output_folder",
-    "chunked_prefill_size",
-    "device",
-    "disable_chunked_prefix_cache",
-    "disable_flashinfer_cutlass_moe_fp4_allgather",
-    "disable_radix_cache",
-    "enable_dp_lm_head",
-    "flashinfer_mxfp4_moe_precision",
-    "enable_flashinfer_allreduce_fusion",
-    "moe_dense_tp_size",
-    "ep_dispatch_algorithm",
-    "ep_num_redundant_experts",
-    "enable_nan_detection",
-    "flashinfer_mla_disable_ragged",
-    "max_micro_batch_size",
-    "disable_shared_experts_fusion",
-    "sampling_backend",
-    "speculative_accept_threshold_single",
-    "speculative_accept_threshold_acc",
-    "speculative_attention_mode",
-    "torchao_config",
-    "triton_attention_reduce_in_fp32",
-    "num_reserved_decode_tokens",
-    "weight_loader_disable_mmap",
-    "enable_multimodal",
-    "enable_symm_mem",
-    "enable_custom_logit_processor",
-    "disaggregation_mode",
-]
-
-# Put some global args for easy access
-global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}

 logger = logging.getLogger(__name__)

@@ -147,6 +119,18 @@ class FINISH_MATCHED_STR(BaseFinishReason):
     }


+class FINISHED_MATCHED_REGEX(BaseFinishReason):
+    def __init__(self, matched: str):
+        super().__init__()
+        self.matched = matched
+
+    def to_json(self):
+        return {
+            "type": "stop",  # to match OpenAI API's return value
+            "matched": self.matched,
+        }
+
+
 class FINISH_LENGTH(BaseFinishReason):
     def __init__(self, length: int):
         super().__init__()
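Note: the new finish reason pairs with the stop-regex support added to sampling parameters in this release (see sglang/srt/sampling/sampling_params.py, +75 -0 in the file list). A minimal standalone sketch of the matching behavior it reports, using only the stdlib and the internal field names visible later in this diff; the tail string and pattern below are illustrative, not from this diff:

    import re

    # Sketch: the scheduler decodes a short tail of output tokens and finishes
    # the request once any stop regex matches (see _check_str_based_finish below).
    tail_str = "The final answer is DONE-42"   # illustrative decoded tail
    stop_regex_strs = [r"DONE-\d+"]            # SamplingParams.stop_regex_strs

    for pattern in stop_regex_strs:
        if re.search(pattern, tail_str):
            # What FINISHED_MATCHED_REGEX(matched=pattern).to_json() reports:
            finish_reason = {"type": "stop", "matched": pattern}
            break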
@@ -407,6 +391,23 @@ class MultimodalInputs:
         # other args would be kept intact


+class RequestStage(str, enum.Enum):
+    # prefill
+    PREFILL_WAITING = "prefill_waiting"
+
+    # disaggregation prefill
+    PREFILL_PREPARE = "prefill_prepare"
+    PREFILL_BOOTSTRAP = "prefill_bootstrap"
+    PREFILL_FORWARD = "prefill_forward"
+    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"
+
+    # disaggregation decode
+    DECODE_PREPARE = "decode_prepare"
+    DECODE_BOOTSTRAP = "decode_bootstrap"
+    DECODE_WAITING = "decode_waiting"
+    DECODE_TRANSFERRED = "decode_transferred"
+
+
 class Req:
     """The input and output status of a request."""

@@ -431,8 +432,13 @@ class Req:
         bootstrap_host: Optional[str] = None,
         bootstrap_port: Optional[int] = None,
         bootstrap_room: Optional[int] = None,
+        disagg_mode: Optional[DisaggregationMode] = None,
         data_parallel_rank: Optional[int] = None,
         vocab_size: Optional[int] = None,
+        priority: Optional[int] = None,
+        metrics_collector: Optional[SchedulerMetricsCollector] = None,
+        extra_key: Optional[str] = None,
+        http_worker_ipc: Optional[str] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -456,6 +462,9 @@ class Req:
         # The length of KV that have been removed in local attention chunked prefill
         self.evicted_seqlen_local = 0

+        # For multi-http worker
+        self.http_worker_ipc = http_worker_ipc
+
         # Sampling info
         if isinstance(sampling_params.custom_params, dict):
             sampling_params = copy.copy(sampling_params)
@@ -465,14 +474,25 @@ class Req:
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
+
+        # extra key for classifying the request (e.g. cache_salt)
+        if lora_id is not None:
+            extra_key = (
+                extra_key or ""
+            ) + lora_id  # lora_id is concatenated to the extra key
+
+        self.extra_key = extra_key
         self.lora_id = lora_id

         # Memory pool info
         self.req_pool_idx: Optional[int] = None
+        self.mamba_pool_idx: Optional[torch.Tensor] = None  # shape (1)

         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        # finished position (in output_ids), used when checking stop conditions with speculative decoding
+        self.finished_len = None
         # Whether this request has finished output
         self.finished_output = None
         # If we want to abort the request in the middle of the event loop, set this to true
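Note: extra_key changes how prefix-cache entries are keyed. Requests with different extra keys can never share cached prefixes, which isolates cache_salt values and, with the concatenation shown above, LoRA adapters as well (this replaces the dedicated LoRARadixCache removed later in this diff). A sketch of the resulting key construction; the values are illustrative, and RadixKey is the class imported at the top of this file:

    from sglang.srt.mem_cache.radix_cache import RadixKey

    extra_key = "tenant-a"        # e.g. derived from the request's cache_salt
    lora_id = "sql-lora-v1"
    if lora_id is not None:
        extra_key = (extra_key or "") + lora_id   # as in Req.__init__ above

    key = RadixKey(token_ids=[101, 2023, 2003], extra_key=extra_key)
    # Two requests with identical token_ids but different extra_key values
    # match different radix-tree paths, so their KV prefixes are never shared.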
@@ -483,6 +503,7 @@ class Req:
         self.stream = stream
         self.eos_token_ids = eos_token_ids
         self.vocab_size = vocab_size
+        self.priority = priority

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -502,7 +523,7 @@ class Req:

         # Prefix info
         # The indices to kv cache for the shared prefix.
-        self.prefix_indices: torch.Tensor = []
+        self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
         # Number of tokens to run prefill.
         self.extend_input_len = 0
         # The relative logprob_start_len in an extend batch
@@ -512,6 +533,8 @@ class Req:
         self.host_hit_length = 0
         # The node to lock until for swa radix tree lock ref
         self.swa_uuid_for_lock: Optional[int] = None
+        # The prefix length of the last prefix matching
+        self.last_matched_prefix_len: int = 0

         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
@@ -573,6 +596,8 @@ class Req:
         ) = None
         self.hidden_states: List[List[float]] = []
         self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
+        self.output_topk_p = None
+        self.output_topk_index = None

         # Embedding (return values)
         self.embedding = None
@@ -589,11 +614,15 @@ class Req:
         # This is used to compute the average acceptance length per request.
         self.spec_verify_ct = 0

+        # The number of accepted tokens in speculative decoding for this request.
+        # This is used to compute the acceptance rate and average acceptance length per request.
+        self.spec_accepted_tokens = 0
+
         # For metrics
-        self.time_stats: TimeStats = TimeStats()
+        self.metrics_collector = metrics_collector
+        self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
         self.has_log_time_stats: bool = False
-        self.queue_time_start = None
-        self.queue_time_end = None
+        self.last_tic = time.monotonic()

         # For disaggregation
         self.bootstrap_host: str = bootstrap_host
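Note: spec_accepted_tokens complements the existing spec_verify_ct counter. A hedged sketch of the per-request statistics the two enable; the exact formulas used by the metrics collector are not shown in this diff:

    # Assumed definitions: one verify pass always emits at least one token,
    # and each pass proposes `num_draft_tokens` speculative tokens.
    def spec_stats(spec_verify_ct, spec_accepted_tokens, num_draft_tokens):
        if spec_verify_ct == 0:
            return None
        avg_accept_length = 1 + spec_accepted_tokens / spec_verify_ct
        accept_rate = spec_accepted_tokens / (spec_verify_ct * num_draft_tokens)
        return avg_accept_length, accept_rate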
@@ -624,7 +653,27 @@ class Req:
     @property
     def is_prefill_only(self) -> bool:
         """Check if this request is prefill-only (no token generation needed)."""
-        return self.sampling_params.max_new_tokens == 0
+        # NOTE: when spec is enabled, prefill_only optimizations are disabled
+
+        spec_alg = get_global_server_args().speculative_algorithm
+        return self.sampling_params.max_new_tokens == 0 and spec_alg is None
+
+    @property
+    def output_ids_through_stop(self) -> List[int]:
+        """Get the output ids through the stop condition. Stop position is included."""
+        if self.finished_len is not None:
+            return self.output_ids[: self.finished_len]
+        return self.output_ids
+
+    def add_latency(self, stage: RequestStage):
+        if self.metrics_collector is None:
+            return
+
+        now = time.monotonic()
+        self.metrics_collector.observe_per_stage_req_latency(
+            stage.value, now - self.last_tic
+        )
+        self.last_tic = now

     def extend_image_inputs(self, image_inputs):
         if self.multimodal_inputs is None:
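Note: RequestStage (defined earlier in this diff) and add_latency together form a segmented stopwatch: each call observes the time elapsed since last_tic under the given stage label, then resets the clock. A runnable stand-in for the mechanism; ReqStopwatch and the printing observer are illustrative, only add_latency/last_tic and the stage names come from this diff:

    import time

    class ReqStopwatch:
        """Stand-in for Req's last_tic/add_latency pair; illustrative only."""
        def __init__(self, observe):
            # observe plays the role of SchedulerMetricsCollector.observe_per_stage_req_latency
            self.observe = observe
            self.last_tic = time.monotonic()

        def add_latency(self, stage: str):
            now = time.monotonic()
            self.observe(stage, now - self.last_tic)
            self.last_tic = now

    req = ReqStopwatch(lambda stage, dt: print(f"{stage}: {dt:.6f}s"))
    req.add_latency("prefill_prepare")      # RequestStage.PREFILL_PREPARE.value
    req.add_latency("prefill_bootstrap")    # RequestStage.PREFILL_BOOTSTRAP.value
    req.add_latency("prefill_forward")      # RequestStage.PREFILL_FORWARD.value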
@@ -636,72 +685,163 @@ class Req:
         # Whether request reached finished condition
         return self.finished_reason is not None

-    def init_next_round_input(
-        self,
-        tree_cache: Optional[BasePrefixCache] = None,
-    ):
-        self.fill_ids = self.origin_input_ids + self.output_ids
-        if tree_cache is not None:
-            if isinstance(tree_cache, LoRARadixCache):
-                (
-                    self.prefix_indices,
-                    self.last_node,
-                    self.last_host_node,
-                    self.host_hit_length,
-                ) = tree_cache.match_prefix_with_lora_id(
-                    key=LoRAKey(
-                        lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids()
-                    ),
-                )
-            else:
-                (
-                    self.prefix_indices,
-                    self.last_node,
-                    self.last_host_node,
-                    self.host_hit_length,
-                ) = tree_cache.match_prefix(
-                    key=self.adjust_max_prefix_ids(),
-                )
-        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
-
-    def adjust_max_prefix_ids(self):
+    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
         self.fill_ids = self.origin_input_ids + self.output_ids
         input_len = len(self.fill_ids)
-
-        # FIXME: To work around some bugs in logprob computation, we need to ensure each
-        # request has at least one token. Later, we can relax this requirement and use `input_len`.
+        # NOTE: the matched length is at most 1 less than the input length to enable logprob computation
         max_prefix_len = input_len - 1
-
-        if self.sampling_params.max_new_tokens > 0:
-            # Need at least one token to compute logits
-            max_prefix_len = min(max_prefix_len, input_len - 1)
-
         if self.return_logprob:
             max_prefix_len = min(max_prefix_len, self.logprob_start_len)
-
         max_prefix_len = max(max_prefix_len, 0)
-        return self.fill_ids[:max_prefix_len]
+        token_ids = self.fill_ids[:max_prefix_len]
+
+        if tree_cache is not None:
+            (
+                self.prefix_indices,
+                self.last_node,
+                self.last_host_node,
+                self.host_hit_length,
+            ) = tree_cache.match_prefix(
+                key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
+                **(
+                    {"req": self, "cow_mamba": True}
+                    if isinstance(tree_cache, MambaRadixCache)
+                    else {}
+                ),
+            )
+        self.last_matched_prefix_len = len(self.prefix_indices)
+        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None

+        output_ids = self.output_ids_through_stop
+
         if first_iter:
             self.read_offset = len(self.origin_input_ids_unpadded)
             self.surr_offset = max(
                 self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
             )
             self.surr_and_decode_ids = (
-                self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
+                self.origin_input_ids_unpadded[self.surr_offset :] + output_ids
             )
-            self.cur_decode_ids_len = len(self.output_ids)
+            self.cur_decode_ids_len = len(output_ids)
         else:
-            self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
-            self.cur_decode_ids_len = len(self.output_ids)
+            self.surr_and_decode_ids.extend(output_ids[self.cur_decode_ids_len :])
+            self.cur_decode_ids_len = len(output_ids)

         return self.surr_and_decode_ids, self.read_offset - self.surr_offset

-    def check_finished(self):
+    def tail_str(self) -> str:
+        # Check stop strings and stop regex patterns together
+        if (
+            len(self.sampling_params.stop_strs) > 0
+            or len(self.sampling_params.stop_regex_strs) > 0
+        ):
+            max_len_tail_str = max(
+                self.sampling_params.stop_str_max_len + 1,
+                self.sampling_params.stop_regex_max_len + 1,
+            )
+
+            tail_len = min((max_len_tail_str + 1), len(self.output_ids))
+            return self.tokenizer.decode(self.output_ids[-tail_len:])
+
+    def check_match_stop_str_prefix(self) -> bool:
+        """
+        Check if the suffix of tail_str overlaps with any stop_str prefix
+        """
+        if not self.sampling_params.stop_strs:
+            return False
+
+        tail_str = self.tail_str()
+
+        # Early return if tail_str is empty
+        if not tail_str:
+            return False
+
+        for stop_str in self.sampling_params.stop_strs:
+            if not stop_str:
+                continue
+            # Check if stop_str is contained in tail_str (fastest check first)
+            if stop_str in tail_str:
+                return True
+
+            # Check if tail_str suffix matches stop_str prefix
+            # Only check if stop_str is not empty, it's for stream output
+            min_len = min(len(tail_str), len(stop_str))
+            for i in range(1, min_len + 1):
+                if tail_str[-i:] == stop_str[:i]:
+                    return True
+
+        return False
+
+    def _check_token_based_finish(self, new_accepted_tokens: List[int]) -> bool:
+        if self.sampling_params.ignore_eos:
+            return False
+
+        # Check stop token ids
+        matched_eos = False
+
+        for i, token_id in enumerate(new_accepted_tokens):
+            if self.sampling_params.stop_token_ids:
+                matched_eos |= token_id in self.sampling_params.stop_token_ids
+            if self.eos_token_ids:
+                matched_eos |= token_id in self.eos_token_ids
+            if self.tokenizer is not None:
+                matched_eos |= token_id == self.tokenizer.eos_token_id
+                if self.tokenizer.additional_stop_token_ids:
+                    matched_eos |= token_id in self.tokenizer.additional_stop_token_ids
+            if matched_eos:
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=token_id)
+                matched_pos = len(self.output_ids) - len(new_accepted_tokens) + i
+                self.finished_len = matched_pos + 1
+                return True
+
+        return False
+
+    def _check_str_based_finish(self):
+        if (
+            len(self.sampling_params.stop_strs) > 0
+            or len(self.sampling_params.stop_regex_strs) > 0
+        ):
+            tail_str = self.tail_str()
+
+            # Check stop strings
+            if len(self.sampling_params.stop_strs) > 0:
+                for stop_str in self.sampling_params.stop_strs:
+                    if stop_str in tail_str or stop_str in self.decoded_text:
+                        self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
+                        return True
+
+            # Check stop regex
+            if len(self.sampling_params.stop_regex_strs) > 0:
+                for stop_regex_str in self.sampling_params.stop_regex_strs:
+                    if re.search(stop_regex_str, tail_str):
+                        self.finished_reason = FINISHED_MATCHED_REGEX(
+                            matched=stop_regex_str
+                        )
+                        return True
+
+        return False
+
+    def _check_vocab_boundary_finish(self, new_accepted_tokens: List[int] = None):
+        for i, token_id in enumerate(new_accepted_tokens):
+            if token_id > self.vocab_size or token_id < 0:
+                offset = len(self.output_ids) - len(new_accepted_tokens) + i
+                if self.sampling_params.stop_token_ids:
+                    self.output_ids[offset] = next(
+                        iter(self.sampling_params.stop_token_ids)
+                    )
+                if self.eos_token_ids:
+                    self.output_ids[offset] = next(iter(self.eos_token_ids))
+                self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
+                self.finished_len = offset + 1
+                return True
+
+        return False
+
+    def check_finished(self, new_accepted_len: int = 1):
         if self.finished():
             return

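Note: check_match_stop_str_prefix above exists for streaming: if the decoded tail could still grow into a stop string, the emitted text must be held back rather than streamed out. A standalone restatement of the same suffix/prefix overlap test, with illustrative inputs:

    def may_become_stop(tail_str: str, stop_strs) -> bool:
        # Same logic as Req.check_match_stop_str_prefix: full containment first,
        # then "suffix of tail == prefix of stop_str" partial overlaps.
        for stop_str in stop_strs:
            if not stop_str:
                continue
            if stop_str in tail_str:
                return True
            min_len = min(len(tail_str), len(stop_str))
            for i in range(1, min_len + 1):
                if tail_str[-i:] == stop_str[:i]:
                    return True
        return False

    assert may_become_stop("Hello <|e", ["<|eot_id|>"])        # partial overlap: hold back
    assert not may_become_stop("Hello there", ["<|eot_id|>"])  # safe to stream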
@@ -715,6 +855,7 @@ class Req:
             self.finished_reason = FINISH_LENGTH(
                 length=self.sampling_params.max_new_tokens
             )
+            self.finished_len = self.sampling_params.max_new_tokens
             return

         if self.grammar is not None:
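Note: finished_len matters mainly under speculative decoding, where one verify step can append several tokens at once and the stop token may land in the middle of the accepted run; output_ids_through_stop (added earlier) then trims the overshoot. A small worked example with illustrative token ids:

    # The last verify step accepted 4 tokens ([22, 2, 33, 44]); the EOS (id 2)
    # was the 2nd of them, so _check_token_based_finish sets finished_len = 3.
    output_ids = [11, 22, 2, 33, 44]
    finished_len = 3                                   # stop position + 1
    assert output_ids[:finished_len] == [11, 22, 2]    # what output_ids_through_stop returns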
@@ -722,47 +863,19 @@ class Req:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
             return

-        last_token_id = self.output_ids[-1]
-
-        if not self.sampling_params.ignore_eos:
-            matched_eos = False
+        new_accepted_tokens = self.output_ids[-new_accepted_len:]

-            # Check stop token ids
-            if self.sampling_params.stop_token_ids:
-                matched_eos = last_token_id in self.sampling_params.stop_token_ids
-            if self.eos_token_ids:
-                matched_eos |= last_token_id in self.eos_token_ids
-            if self.tokenizer is not None:
-                matched_eos |= last_token_id == self.tokenizer.eos_token_id
-                if self.tokenizer.additional_stop_token_ids:
-                    matched_eos |= (
-                        last_token_id in self.tokenizer.additional_stop_token_ids
-                    )
-            if matched_eos:
-                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
-                return
-
-        if last_token_id > self.vocab_size or last_token_id < 0:
-            if self.sampling_params.stop_token_ids:
-                self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
-            if self.eos_token_ids:
-                self.output_ids[-1] = next(iter(self.eos_token_ids))
-            self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
+        if self._check_token_based_finish(new_accepted_tokens):
             return

-        # Check stop strings
-        if len(self.sampling_params.stop_strs) > 0:
-            tail_str = self.tokenizer.decode(
-                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
-            )
+        if self._check_vocab_boundary_finish(new_accepted_tokens):
+            return

-            for stop_str in self.sampling_params.stop_strs:
-                if stop_str in tail_str or stop_str in self.decoded_text:
-                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
-                    return
+        if self._check_str_based_finish():
+            return

     def reset_for_retract(self):
-        self.prefix_indices = []
+        self.prefix_indices = torch.empty((0,), dtype=torch.int64)
         self.last_node = None
         self.swa_uuid_for_lock = None
         self.extend_input_len = 0
@@ -772,7 +885,7 @@ class Req:
         self.temp_input_top_logprobs_idx = None
         self.extend_logprob_start_len = 0
         self.is_chunked = 0
-        self.req_pool_idx = None
+        self.mamba_pool_idx = None
         self.already_computed = 0

     def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
@@ -794,10 +907,10 @@ class Req:
             return

         if self.bootstrap_room is not None:
-            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
         else:
-            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
-        logger.info(f"{prefix}: {self.time_stats}")
+            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
+        logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
         self.has_log_time_stats = True

     def set_finish_with_abort(self, error_msg: str):
@@ -820,10 +933,6 @@ class Req:
         )


-# Batch id
-bid = 0
-
-
 @dataclasses.dataclass
 class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     """Store all information of a batch on the scheduler."""
@@ -844,15 +953,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

-    # Events
-    launch_done: Optional[threading.Event] = None
-
     # For chunked prefill in PP
     chunked_req: Optional[Req] = None

     # Sampling info
     sampling_info: SamplingBatchInfo = None
-    next_batch_sampling_info: SamplingBatchInfo = None

     # Batched arguments to model runner
     input_ids: torch.Tensor = None  # shape: [b], int64
@@ -860,6 +965,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     token_type_ids: torch.Tensor = None  # shape: [b], int64
     req_pool_indices: torch.Tensor = None  # shape: [b], int64
     seq_lens: torch.Tensor = None  # shape: [b], int64
+    seq_lens_cpu: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
     out_cache_loc: torch.Tensor = None  # shape: [b], int64
     output_ids: torch.Tensor = None  # shape: [b], int64
@@ -915,7 +1021,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
+    # spec_info: Optional[SpecInput] = None
+    spec_info: Optional[SpecInput] = None

     # Whether to return hidden states
     return_hidden_states: bool = False
@@ -973,107 +1080,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def is_empty(self):
         return len(self.reqs) == 0

-    def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
-        if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
-            req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
-        else:
-            req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
-        if req_pool_indices is None:
-            raise RuntimeError(
-                "alloc_req_slots runs out of memory. "
-                "Please set a smaller number for `--max-running-requests`. "
-                f"{self.req_to_token_pool.available_size()=}, "
-                f"{num_reqs=}, "
-            )
-        return req_pool_indices
-
-    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
-        self._evict_tree_cache_if_needed(num_tokens)
-
-        if backup_state:
-            state = self.token_to_kv_pool_allocator.backup_state()
-
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
-        if out_cache_loc is None:
-            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
-            error_msg = (
-                f"{phase_str} out of memory. Try to lower your batch size.\n"
-                f"Try to allocate {num_tokens} tokens.\n"
-                f"{self._available_and_evictable_str()}"
-            )
-            logger.error(error_msg)
-            if self.tree_cache is not None:
-                self.tree_cache.pretty_print()
-            raise RuntimeError(error_msg)
-
-        if backup_state:
-            return out_cache_loc, state
-        else:
-            return out_cache_loc
-
-    def alloc_paged_token_slots_extend(
-        self,
-        prefix_lens: torch.Tensor,
-        seq_lens: torch.Tensor,
-        last_loc: torch.Tensor,
-        extend_num_tokens: int,
-        backup_state: bool = False,
-    ):
-        # Over estimate the number of tokens: assume each request needs a new page.
-        num_tokens = (
-            extend_num_tokens
-            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
-        )
-        self._evict_tree_cache_if_needed(num_tokens)
-
-        if backup_state:
-            state = self.token_to_kv_pool_allocator.backup_state()
-
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
-            prefix_lens, seq_lens, last_loc, extend_num_tokens
-        )
-        if out_cache_loc is None:
-            error_msg = (
-                f"Prefill out of memory. Try to lower your batch size.\n"
-                f"Try to allocate {extend_num_tokens} tokens.\n"
-                f"{self._available_and_evictable_str()}"
-            )
-            logger.error(error_msg)
-            raise RuntimeError(error_msg)
-
-        if backup_state:
-            return out_cache_loc, state
-        else:
-            return out_cache_loc
-
-    def alloc_paged_token_slots_decode(
-        self,
-        seq_lens: torch.Tensor,
-        last_loc: torch.Tensor,
-        backup_state: bool = False,
-    ):
-        # Over estimate the number of tokens: assume each request needs a new page.
-        num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
-        self._evict_tree_cache_if_needed(num_tokens)
-
-        if backup_state:
-            state = self.token_to_kv_pool_allocator.backup_state()
-
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
-        if out_cache_loc is None:
-            error_msg = (
-                f"Decode out of memory. Try to lower your batch size.\n"
-                f"Try to allocate {len(seq_lens)} tokens.\n"
-                f"{self._available_and_evictable_str()}"
-            )
-            logger.error(error_msg)
-            raise RuntimeError(error_msg)
-
-        if backup_state:
-            return out_cache_loc, state
-        else:
-            return out_cache_loc
-
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
         self.encoder_cached = []
@@ -1128,6 +1134,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
+        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
 
         if not decoder_out_cache_loc:
             self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
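
This hunk is the first of several that thread a host-side seq_lens_cpu twin next to the device-resident seq_lens. Keeping the copy current at every mutation point lets later code (see get_model_worker_batch further down, which previously called self.seq_lens.cpu()) read lengths without a blocking device sync. A small sketch of the invariant, with made-up values:

import torch

seq_lens_list = [5, 9, 12]
device = "cuda" if torch.cuda.is_available() else "cpu"

# The device tensor is copied asynchronously; the CPU twin stays on the host.
seq_lens = torch.tensor(seq_lens_list, dtype=torch.int64).to(device, non_blocking=True)
seq_lens_cpu = torch.tensor(seq_lens_list, dtype=torch.int64)

# A consumer needing host-side lengths reads the twin directly,
# instead of forcing a device sync with seq_lens.cpu().
assert seq_lens_cpu.tolist() == seq_lens_list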
@@ -1150,10 +1157,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
 
-        # Allocate req slots
-        bs = len(self.reqs)
-        req_pool_indices = self.alloc_req_slots(bs, self.reqs)
-
         # Init tensors
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -1167,21 +1170,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             r.token_type_ids for r in reqs if r.token_type_ids is not None
         ]
 
-        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
         input_ids_tensor = torch.tensor(
             list(chain.from_iterable(input_ids)), dtype=torch.int64
         ).to(self.device, non_blocking=True)
         seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
+        seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
         orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        prefix_lens_tensor = torch.tensor(
-            prefix_lens, dtype=torch.int64, device=self.device
-        )
 
         token_type_ids_tensor = None
         if len(token_type_ids) > 0:
@@ -1189,9 +1187,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 sum(token_type_ids, []), dtype=torch.int64
             ).to(self.device, non_blocking=True)
 
-        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
+        # Set batch fields needed by alloc_for_extend
+        self.prefix_lens = prefix_lens
+        self.extend_lens = extend_lens
+        self.seq_lens = seq_lens_tensor
+        self.seq_lens_cpu = seq_lens_cpu
+        self.extend_num_tokens = extend_num_tokens
 
-        # Copy prefix and do some basic check
+        # Allocate memory
+        out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
+            self
+        )
+
+        # Set fields
         input_embeds = []
         extend_input_logprob_token_ids = []
         multimodal_inputs = []
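
prepare_for_extend now publishes its lengths onto the batch and delegates allocation to a free function imported from elsewhere in the package. Judging only from this call site, alloc_for_extend(batch) returns the flat KV slot locations plus the request-pool indices as both a device tensor and a Python list. A toy stand-in for that contract (the allocation logic below is invented for illustration; the real allocator is page-aware and lives outside this file):

from typing import List, Tuple
import torch

class ToyBatch:
    def __init__(self, num_reqs: int, extend_num_tokens: int):
        self.num_reqs = num_reqs
        self.extend_num_tokens = extend_num_tokens

def alloc_for_extend(batch: ToyBatch) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
    # Hand out request slots and token slots sequentially; the real
    # implementation also writes the req-to-token mapping (see below).
    req_pool_indices = list(range(batch.num_reqs))
    out_cache_loc = torch.arange(batch.extend_num_tokens, dtype=torch.int64)
    return out_cache_loc, torch.tensor(req_pool_indices, dtype=torch.int64), req_pool_indices

out, idx_t, idx = alloc_for_extend(ToyBatch(num_reqs=2, extend_num_tokens=10))
assert out.numel() == 10 and idx == [0, 1]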
@@ -1200,15 +1208,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             req.req_pool_idx = req_pool_indices[i]
             assert seq_len - pre_len == req.extend_input_len
 
-            if pre_len > 0:
-                self.req_to_token_pool.write(
-                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
-                )
-            if isinstance(self.tree_cache, SWAChunkCache):
-                self.tree_cache.evict_swa(
-                    req, pre_len, self.model_config.attention_chunk_size
-                )
-
             # If input_embeds are available, store them
             if req.input_embeds is not None:
                 # If req.input_embeds is already a list, append its content directly
@@ -1298,23 +1297,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         else:
             extend_input_logprob_token_ids = None
 
-        # Allocate memory
-        if self.token_to_kv_pool_allocator.page_size == 1:
-            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
-        else:
-            last_loc = get_last_loc(
-                self.req_to_token_pool.req_to_token,
-                req_pool_indices_tensor,
-                prefix_lens_tensor,
-            )
-            out_cache_loc = self.alloc_paged_token_slots_extend(
-                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
-            )
-
-        # Set fields
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
-        self.seq_lens = seq_lens_tensor
         self.orig_seq_lens = orig_seq_lens_tensor
         self.out_cache_loc = out_cache_loc
         self.input_embeds = (
@@ -1338,33 +1322,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
 
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
-        self.extend_num_tokens = extend_num_tokens
-        self.prefix_lens = prefix_lens
-        self.extend_lens = extend_lens
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
 
-        # Write to req_to_token_pool
-        if support_triton(global_server_args_dict.get("attention_backend")):
-            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
-
-            write_req_to_token_pool_triton[(bs,)](
-                self.req_to_token_pool.req_to_token,
-                req_pool_indices_tensor,
-                prefix_lens_tensor,
-                seq_lens_tensor,
-                extend_lens_tensor,
-                out_cache_loc,
-                self.req_to_token_pool.req_to_token.shape[1],
-            )
-        else:
-            pt = 0
-            for i in range(bs):
-                self.req_to_token_pool.write(
-                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
-                    out_cache_loc[pt : pt + extend_lens[i]],
-                )
-                pt += extend_lens[i]
-
         if self.model_config.is_encoder_decoder:
             self.prepare_encoder_info_extend(input_ids, seq_lens)
 
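The block deleted above was the last in-file caller of the req-to-token write-back, a Triton kernel with a pure Python fallback (both removed at the bottom of this file); that write now happens inside alloc_for_extend. The fallback's semantics are easy to restate with plain tensor indexing; a runnable sketch on CPU tensors with made-up sizes:

import torch

# Mimic the removed fallback loop: for each request, write its newly
# allocated KV slots into req_to_token[row, prefix_len:seq_len].
req_to_token = torch.zeros((2, 16), dtype=torch.int64)
req_pool_indices = [0, 1]
prefix_lens = [2, 0]
seq_lens = [5, 4]
extend_lens = [3, 4]
out_cache_loc = torch.arange(100, 107)  # 3 + 4 = 7 fresh slots

pt = 0
for i in range(2):
    row = req_pool_indices[i]
    req_to_token[row, prefix_lens[i] : seq_lens[i]] = out_cache_loc[pt : pt + extend_lens[i]]
    pt += extend_lens[i]

assert req_to_token[0, 2:5].tolist() == [100, 101, 102]
assert req_to_token[1, 0:4].tolist() == [103, 104, 105, 106]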
@@ -1435,7 +1394,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             * self.token_to_kv_pool_allocator.page_size
         )
 
-        self._evict_tree_cache_if_needed(num_tokens)
+        evict_from_tree_cache(self.tree_cache, num_tokens)
         return self._is_available_size_sufficient(num_tokens)
 
     def retract_decode(self, server_args: ServerArgs):
@@ -1457,7 +1416,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         )
 
         retracted_reqs = []
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
         first_iter = True
         while first_iter or (
             not self.check_decode_mem(selected_indices=sorted_indices)
@@ -1484,37 +1442,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
-
-            if server_args.disaggregation_mode == "decode":
-                req.offload_kv_cache(
-                    self.req_to_token_pool, self.token_to_kv_pool_allocator
-                )
-
-            if isinstance(self.tree_cache, ChunkCache):
-                # ChunkCache does not have eviction
-                token_indices = self.req_to_token_pool.req_to_token[
-                    req.req_pool_idx, : seq_lens_cpu[idx]
-                ]
-                self.token_to_kv_pool_allocator.free(token_indices)
-                self.req_to_token_pool.free(req.req_pool_idx)
-            else:
-                # TODO: apply more fine-grained retraction
-                last_uncached_pos = (
-                    len(req.prefix_indices) // server_args.page_size
-                ) * server_args.page_size
-                token_indices = self.req_to_token_pool.req_to_token[
-                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
-                ]
-                self.token_to_kv_pool_allocator.free(token_indices)
-                self.req_to_token_pool.free(req.req_pool_idx)
-
-                # release the last node
-                if self.is_hybrid:
-                    self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
-                else:
-                    self.tree_cache.dec_lock_ref(req.last_node)
-
-            req.reset_for_retract()
+            # release memory and don't insert into the tree because we need the space instantly
+            self.release_req(idx, len(sorted_indices), server_args)
 
         if len(retracted_reqs) == 0:
             # Corner case: only one request left
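
retract_decode now funnels every evicted request through release_req (added in the next hunk), which frees the request's memory without re-inserting it into the radix tree and immediately evicts enough cache for the survivors. A schematic restatement with toy classes (the toys only mirror the call order visible in the diff; RETRACT_DECODE_STEPS = 20 is an assumed value, the real one comes from envs.SGLANG_RETRACT_DECODE_STEPS):

RETRACT_DECODE_STEPS = 20  # assumed default, illustrative only

class ToyTreeCache:
    def __init__(self):
        self.log = []
    def cache_finished_req(self, req, is_insert=True):
        self.log.append(("free", is_insert))
    def evict(self, num_tokens):
        self.log.append(("evict", num_tokens))

class ToyReq:
    def reset_for_retract(self):
        pass

def release_req_sketch(req, tree_cache, remaining_req_count):
    # Free without publishing to the tree, reclaim for survivors, then reset.
    tree_cache.cache_finished_req(req, is_insert=False)
    tree_cache.evict(remaining_req_count * RETRACT_DECODE_STEPS)
    req.reset_for_retract()

cache = ToyTreeCache()
release_req_sketch(ToyReq(), cache, remaining_req_count=3)
assert cache.log == [("free", False), ("evict", 60)]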
@@ -1529,11 +1458,29 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
 
         new_estimate_ratio = (
-            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
-        ) / total_max_new_tokens
+            total_decoded_tokens
+            + envs.SGLANG_RETRACT_DECODE_STEPS.get() * len(self.reqs)
+        ) / (
+            total_max_new_tokens + 1
+        )  # avoid zero division
         new_estimate_ratio = min(1.0, new_estimate_ratio)
 
-        return retracted_reqs, new_estimate_ratio
+        return retracted_reqs, new_estimate_ratio, []
+
+    def release_req(self, idx: int, remaining_req_count: int, server_args: ServerArgs):
+        req = self.reqs[idx]
+
+        if server_args.disaggregation_mode == "decode":
+            req.offload_kv_cache(
+                self.req_to_token_pool, self.token_to_kv_pool_allocator
+            )
+        # TODO (csy): for preempted requests, we may want to insert into the tree
+        self.tree_cache.cache_finished_req(req, is_insert=False)
+        # NOTE(lsyin): we should use the newly evictable memory instantly.
+        num_tokens = remaining_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
+        evict_from_tree_cache(self.tree_cache, num_tokens)
+
+        req.reset_for_retract()
 
     def prepare_encoder_info_decode(self):
         # Reset the encoder cached status
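
Two changes to the estimate above are easy to miss: the retract-step constant now comes from the SGLANG_RETRACT_DECODE_STEPS environment knob instead of global_config, and the denominator gains a + 1 so a batch whose max_new_tokens sum to zero cannot divide by zero. Worked numbers (all values illustrative, including the assumed step count of 20):

retract_decode_steps = 20  # assumed default of SGLANG_RETRACT_DECODE_STEPS
total_decoded_tokens = 300
num_reqs = 4
total_max_new_tokens = 2048

ratio = (total_decoded_tokens + retract_decode_steps * num_reqs) / (total_max_new_tokens + 1)
ratio = min(1.0, ratio)
assert 0.0 < ratio < 1.0  # 380 / 2049, about 0.185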
@@ -1543,6 +1490,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
         self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
@@ -1553,11 +1501,21 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.model_config.vocab_size,
         )
 
+    @property
+    def is_v2_eagle(self):
+        # FIXME: finally deprecate is_v2_eagle
+        return self.enable_overlap and self.spec_algorithm.is_eagle()
+
     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
         bs = len(self.reqs)
 
-        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
+        if self.is_v2_eagle:
+            # TODO(spec-v2): all v2 spec should go through this path
+            draft_input: EagleDraftInput = self.spec_info
+            draft_input.prepare_for_decode(self)
+
+        if not self.spec_algorithm.is_none():
             # if spec decoding is used, the decode batch is prepared inside
             # `forward_batch_speculative_generation` after running draft models.
             return
@@ -1590,48 +1548,39 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.output_ids = None
 
         if self.model_config.is_encoder_decoder:
-            locs = self.encoder_lens + self.seq_lens
             self.prepare_encoder_info_decode()
-        else:
-            locs = self.seq_lens.clone()
 
+        # Allocate memory
+        self.out_cache_loc = alloc_for_decode(self, token_per_req=1)
+
+        # Update seq_lens after allocation
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.seq_lens = self.seq_lens + 1
+            self.seq_lens_cpu = self.seq_lens_cpu + 1
             self.orig_seq_lens = self.orig_seq_lens + 1
         else:
             # A faster in-place version
             self.seq_lens.add_(1)
+            self.seq_lens_cpu.add_(1)
             self.orig_seq_lens.add_(1)
         self.seq_lens_sum += bs
 
-        # free memory
-        if isinstance(self.tree_cache, SWAChunkCache):
-            for req in self.reqs:
-                self.tree_cache.evict_swa(
-                    req, req.seqlen - 1, self.model_config.attention_chunk_size
-                )
-
-        # Allocate memory
-        if self.token_to_kv_pool_allocator.page_size == 1:
-            self.out_cache_loc = self.alloc_token_slots(bs)
-        else:
-            last_loc = self.req_to_token_pool.req_to_token[
-                self.req_pool_indices, self.seq_lens - 2
-            ]
-            self.out_cache_loc = self.alloc_paged_token_slots_decode(
-                self.seq_lens, last_loc
-            )
-
-        self.req_to_token_pool.write(
-            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
-        )
+    def maybe_wait_verify_done(self):
+        if self.is_v2_eagle:
+            draft_input: EagleDraftInput = self.spec_info
+            if draft_input.verify_done is not None:
+                draft_input.verify_done.synchronize()
 
     def filter_batch(
         self,
         chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
         keep_indices: Optional[List[int]] = None,
     ):
+        # FIXME(lsyin): used here to get the correct seq_lens
+        # The batch has been launched but we need it verified to get correct next batch info
+        self.maybe_wait_verify_done()
+
         if keep_indices is None:
             if isinstance(chunked_req_to_exclude, Req):
                 chunked_req_to_exclude = [chunked_req_to_exclude]
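
prepare_for_decode now allocates first (via alloc_for_decode) and then advances seq_lens, seq_lens_cpu, and orig_seq_lens together, preserving the overlap-mode distinction: out-of-place + 1 builds fresh tensors, which is safe while a still-running step may read the old ones, whereas add_(1) mutates shared storage. The difference in one runnable snippet (toy values):

import torch

lens = torch.tensor([7, 11], dtype=torch.int64)

# Out-of-place (overlap mode): the old tensor object is left untouched,
# so anything still holding a reference sees consistent values.
old = lens
lens = lens + 1
assert old.tolist() == [7, 11] and lens.tolist() == [8, 12]

# In-place (no overlap): cheaper, but mutates the shared storage.
lens.add_(1)
assert lens.tolist() == [9, 13]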
@@ -1666,6 +1615,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
+        self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
         self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
         self.seq_lens_sum = self.seq_lens.sum().item()
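
Note how filter_batch indexes the host twin with the plain Python keep_indices list while the device tensors use keep_indices_device; both index forms select the same rows, so the two copies stay aligned. A minimal demonstration:

import torch

seq_lens_cpu = torch.tensor([5, 9, 12, 3], dtype=torch.int64)
keep_indices = [0, 2]
keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64)

# A Python list and an index tensor produce the same row selection.
assert seq_lens_cpu[keep_indices].tolist() == seq_lens_cpu[keep_indices_device].tolist() == [5, 12]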
@@ -1683,9 +1633,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
         self.sampling_info.filter_batch(keep_indices, keep_indices_device)
         if self.spec_info:
-            self.spec_info.filter_batch(keep_indices_device)
+            if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
+                has_been_filtered = False
+            else:
+                has_been_filtered = True
+            self.spec_info.filter_batch(
+                new_indices=keep_indices_device,
+                has_been_filtered=has_been_filtered,
+            )
 
     def merge_batch(self, other: "ScheduleBatch"):
+        # NOTE: in v2 eagle mode, we do not need wait verify here because
+        # 1) current batch is always prefill, whose seq_lens and allocate_lens are not a future
+        # 2) other batch is always decode, which is finished in previous step
+
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
         # needs to be called with pre-merged Batch.reqs.
@@ -1699,6 +1660,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+        self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
         self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
         self.out_cache_loc = None
         self.seq_lens_sum += other.seq_lens_sum
@@ -1742,15 +1704,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            self.sampling_info.grammars = None
 
         seq_lens_cpu = (
-            seq_lens_cpu_cache
-            if seq_lens_cpu_cache is not None
-            else self.seq_lens.cpu()
+            seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
         )
 
-        global bid
-        bid += 1
         return ModelWorkerBatch(
-            bid=bid,
             forward_mode=self.forward_mode,
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
@@ -1796,7 +1753,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 )
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
-            launch_done=self.launch_done,
             is_prefill_only=self.is_prefill_only,
         )
 
@@ -1804,6 +1760,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
+            req_to_token_pool=self.req_to_token_pool,
+            req_pool_indices=self.req_pool_indices,
             model_config=self.model_config,
             forward_mode=self.forward_mode,
             out_cache_loc=self.out_cache_loc,
@@ -1815,26 +1773,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             is_extend_in_batch=self.is_extend_in_batch,
             is_prefill_only=self.is_prefill_only,
+            seq_lens_cpu=self.seq_lens_cpu,
+            enable_overlap=self.enable_overlap,
         )
 
-    def _evict_tree_cache_if_needed(self, num_tokens: int):
-        if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
-            return
-
-        if self.is_hybrid:
-            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
-            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
-
-            if full_available_size < num_tokens or swa_available_size < num_tokens:
-                if self.tree_cache is not None:
-                    full_num_tokens = max(0, num_tokens - full_available_size)
-                    swa_num_tokens = max(0, num_tokens - swa_available_size)
-                    self.tree_cache.evict(full_num_tokens, swa_num_tokens)
-        else:
-            if self.token_to_kv_pool_allocator.available_size() < num_tokens:
-                if self.tree_cache is not None:
-                    self.tree_cache.evict(num_tokens)
-
     def _is_available_size_sufficient(self, num_tokens: int) -> bool:
         if self.is_hybrid:
             return (
@@ -1844,23 +1786,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         else:
             return self.token_to_kv_pool_allocator.available_size() >= num_tokens
 
-    def _available_and_evictable_str(self) -> str:
-        if self.is_hybrid:
-            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
-            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
-            full_evictable_size = self.tree_cache.full_evictable_size()
-            swa_evictable_size = self.tree_cache.swa_evictable_size()
-            return (
-                f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
-                f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
-                f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
-                f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
-            )
-        else:
-            available_size = self.token_to_kv_pool_allocator.available_size()
-            evictable_size = self.tree_cache.evictable_size()
-            return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
-
     def __str__(self):
         return (
             f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
@@ -1870,8 +1795,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
 @dataclasses.dataclass
 class ModelWorkerBatch:
-    # The batch id
-    bid: int
     # The forward mode
     forward_mode: ForwardMode
     # The input ids
@@ -1932,124 +1855,12 @@ class ModelWorkerBatch:
 
     # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
+
+    spec_info: Optional[SpecInput] = None
+
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1
 
-    # Overlap event
-    launch_done: Optional[threading.Event] = None
-
     # Whether this batch is prefill-only (no token generation needed)
     is_prefill_only: bool = False
-
-
-@triton.jit
-def write_req_to_token_pool_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices,
-    pre_lens,
-    seq_lens,
-    extend_lens,
-    out_cache_loc,
-    req_to_token_ptr_stride: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(0)
-
-    req_pool_index = tl.load(req_pool_indices + pid)
-    pre_len = tl.load(pre_lens + pid)
-    seq_len = tl.load(seq_lens + pid)
-
-    # NOTE: This can be slow for large bs
-    cumsum_start = tl.cast(0, tl.int64)
-    for i in range(pid):
-        cumsum_start += tl.load(extend_lens + i)
-
-    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
-    for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = offset < (seq_len - pre_len)
-        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
-        tl.store(
-            req_to_token_ptr
-            + req_pool_index * req_to_token_ptr_stride
-            + offset
-            + pre_len,
-            value,
-            mask=mask,
-        )
-
-
-def get_last_loc(
-    req_to_token: torch.Tensor,
-    req_pool_indices_tensor: torch.Tensor,
-    prefix_lens_tensor: torch.Tensor,
-) -> torch.Tensor:
-    if (
-        global_server_args_dict["attention_backend"] != "ascend"
-        and global_server_args_dict["attention_backend"] != "torch_native"
-    ):
-        impl = get_last_loc_triton
-    else:
-        impl = get_last_loc_torch
-
-    return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
-
-
-def get_last_loc_torch(
-    req_to_token: torch.Tensor,
-    req_pool_indices_tensor: torch.Tensor,
-    prefix_lens_tensor: torch.Tensor,
-) -> torch.Tensor:
-    return torch.where(
-        prefix_lens_tensor > 0,
-        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
-        torch.full_like(prefix_lens_tensor, -1),
-    )
-
-
-@triton.jit
-def get_last_loc_kernel(
-    req_to_token,
-    req_pool_indices_tensor,
-    prefix_lens_tensor,
-    result,
-    num_tokens,
-    req_to_token_stride,
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
-    mask = offset < num_tokens
-
-    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
-    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
-
-    token_mask = prefix_lens > 0
-    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
-    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)
-
-    tl.store(result + offset, tokens, mask=mask)
-
-
-def get_last_loc_triton(
-    req_to_token: torch.Tensor,
-    req_pool_indices_tensor: torch.Tensor,
-    prefix_lens_tensor: torch.Tensor,
-) -> torch.Tensor:
-    BLOCK_SIZE = 256
-    num_tokens = prefix_lens_tensor.shape[0]
-    result = torch.empty_like(prefix_lens_tensor)
-    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)
-
-    get_last_loc_kernel[grid](
-        req_to_token,
-        req_pool_indices_tensor,
-        prefix_lens_tensor,
-        result,
-        num_tokens,
-        req_to_token.stride(0),
-        BLOCK_SIZE,
-    )
-    return result
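
get_last_loc and its Triton kernel leave this file along with the write-back kernel; their callers now live next to the new allocation helpers. For reference, the Torch fallback removed above computes, per request, the KV slot of the last cached prefix token, or -1 when there is no prefix. The same function, exercised on toy CPU tensors:

import torch

def get_last_loc_torch(req_to_token, req_pool_indices, prefix_lens):
    # Verbatim semantics of the removed fallback: index the last prefix
    # token per request; -1 for requests without a prefix. torch.where
    # evaluates both branches, but the stray negative index is legal and
    # its value is discarded by the mask.
    return torch.where(
        prefix_lens > 0,
        req_to_token[req_pool_indices, prefix_lens - 1],
        torch.full_like(prefix_lens, -1),
    )

req_to_token = torch.tensor([[10, 11, 12, 0], [20, 21, 22, 23]])
last = get_last_loc_torch(
    req_to_token,
    torch.tensor([0, 1]),
    torch.tensor([3, 0]),
)
assert last.tolist() == [12, -1]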