sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py +357 -212

@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+import logging
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -15,21 +16,16 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
 import torch
 
-if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import logging
-
-    torch._logging.set_logs(dynamo=logging.ERROR)
-    torch._dynamo.config.suppress_errors = True
-
-from sglang.global_config import global_config
+from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
+    get_int_env_var,
     is_flashinfer_available,
     is_sm100_supported,
     next_power_of_2,
@@ -39,14 +35,21 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+logger = logging.getLogger(__name__)
+
+if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
+
+
 if is_flashinfer_available():
     from flashinfer import (
         BatchDecodeWithPagedKVCacheWrapper,
         BatchPrefillWithPagedKVCacheWrapper,
         BatchPrefillWithRaggedKVCacheWrapper,
+        fast_decode_plan,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import _get_range_buf, get_seq_lens
 
 
 class WrapperDispatch(Enum):
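
The hunks above replace the raw `os.environ["SGLANG_ENABLE_TORCH_COMPILE"]` lookup (which raises `KeyError` when the variable is unset) with the typed `envs` registry from the new `sglang/srt/environ.py`. A minimal sketch of that accessor pattern, assuming a simple wrapper design; everything besides the variable name is illustrative, not the actual `environ.py` API:

```python
import os


class EnvBool:
    """Illustrative typed view over one boolean environment variable."""

    def __init__(self, name: str, default: bool = False):
        self.name = name
        self.default = default

    def get(self) -> bool:
        raw = os.environ.get(self.name)  # no KeyError when unset
        if raw is None:
            return self.default
        return raw.strip().lower() in ("1", "true", "yes")

    def set(self, value: bool) -> None:
        os.environ[self.name] = "1" if value else "0"


class Envs:
    SGLANG_ENABLE_TORCH_COMPILE = EnvBool("SGLANG_ENABLE_TORCH_COMPILE")


envs = Envs()

if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
    print("configure torch._dynamo / torch._logging here")
```

The integer-valued `SGLANG_FLASHINFER_WORKSPACE_SIZE` used later in this diff follows the same `get()`/`set()` shape.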
@@ -54,6 +57,36 @@ class WrapperDispatch(Enum):
     CROSS_ATTENTION = auto()
 
 
+@dataclass
+class MultiItemScoringParams:
+    """Parameters for multi-item scoring in attention computation.
+
+    Used when processing sequences with multiple items separated by delimiters,
+    where each item needs specific attention patterns that respect item boundaries.
+
+    Attributes:
+        prefix_len_ptr: A uint32 1D tensor indicating the prefix length of each prompt.
+            The tensor size is equal to the batch size.
+        token_pos_in_items_ptr: A uint16 1D tensor indicating the token position of each item
+            starting from 0 (delimiter) for each item. For batch size > 1,
+            sequences are concatenated with zero padding to ensure same length.
+        token_pos_in_items_len: Zero padding length for token_pos_in_items_ptr to handle
+            batch_size > 1 case. Defines the padded length for each sequence.
+        max_item_len_ptr: A uint16 tensor containing the max token length of all items
+            for each prompt in the batch.
+    """
+
+    prefix_len_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_len: int = 0
+    max_item_len_ptr: Optional[torch.Tensor] = None
+
+    def is_enabled(self) -> bool:
+        """Check if multi-item scoring is enabled."""
+        return self.prefix_len_ptr is not None
+
+
 @dataclass
 class DecodeMetadata:
     decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
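
The `token_pos_in_items_ptr` layout described above follows the FlashInfer convention (also quoted in the `_process_multi_item_scoring` docstring further down): each delimiter contributes a 0, each item token its 1-based offset within the item, and a final 0 marks the trailing delimiter. A short sketch reproducing the documented example under exactly that assumption:

```python
def build_token_pos_in_items(item_lens):
    """Delimiter -> 0, item tokens -> 1..len, trailing delimiter -> 0."""
    out = []
    for n in item_lens:
        out.append(0)                  # delimiter before the item
        out.extend(range(1, n + 1))    # tokens of the item itself
    out.append(0)                      # trailing delimiter
    return out


# Matches the FlashInfer example for 3 items of length 3, 2, 4:
assert build_token_pos_in_items([3, 2, 4]) == [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
```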
@@ -64,6 +97,7 @@ class PrefillMetadata:
     prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
     use_ragged: bool
     extend_no_prefix: bool
+    multi_item_params: Optional[MultiItemScoringParams] = None
 
 
 # Reuse this workspace buffer across all flashinfer wrappers
@@ -83,9 +117,15 @@ class FlashInferAttnBackend(AttentionBackend):
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
         kv_last_page_len_buf: Optional[torch.Tensor] = None,
+        init_new_workspace: bool = False,
     ):
         super().__init__()
 
+        # Store multi-item scoring delimiter for efficient access
+        self.multi_item_scoring_delimiter = (
+            model_runner.server_args.multi_item_scoring_delimiter
+        )
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -120,18 +160,46 @@ class FlashInferAttnBackend(AttentionBackend):
             or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
             or "MiMoForCausalLM" in model_runner.model_config.hf_config.architectures
         ):
-            global_config.flashinfer_workspace_size = 512 * 1024 * 1024
+            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(512 * 1024 * 1024)
+
+        # When deterministic inference is enabled, tensor cores should be used for decode
+        # Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph
+        # More information can be found here: https://github.com/flashinfer-ai/flashinfer/pull/1675
+        self.enable_deterministic = (
+            model_runner.server_args.enable_deterministic_inference
+        )
+        self.prefill_split_tile_size = None
+        self.decode_split_tile_size = None
+        self.disable_cuda_graph_kv_split = False
+        if self.enable_deterministic:
+            self.decode_use_tensor_cores = True
+            self.prefill_split_tile_size = get_int_env_var(
+                "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
+            )
+            self.decode_split_tile_size = get_int_env_var(
+                "SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
+            )
+            self.disable_cuda_graph_kv_split = True
+            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(2048 * 1024 * 1024)
 
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
+            global_workspace_size = envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get()
             global_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
+                global_workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
-        self.workspace_buffer = global_workspace_buffer
+        if init_new_workspace:
+            self.workspace_buffer = torch.empty(
+                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        else:
+            self.workspace_buffer = global_workspace_buffer
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [
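
The split tile sizes above are read through `get_int_env_var` from `sglang.srt.utils`. Assuming the usual integer-with-default semantics (a sketch, not the actual helper), the deterministic-inference tile sizes can be overridden before launch:

```python
import os


def get_int_env_var(name: str, default: int = 0) -> int:
    """Assumed semantics: parse the variable as an int, fall back to the default."""
    value = os.environ.get(name)
    if value is None or not value.strip():
        return default
    return int(value)


# Hypothetical override of the prefill split tile size (default 4096):
os.environ["SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE"] = "8192"
assert get_int_env_var("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096) == 8192
```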
@@ -204,10 +272,133 @@ class FlashInferAttnBackend(AttentionBackend):
 
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
+
         self.decode_cuda_graph_metadata = {}
         self.prefill_cuda_graph_metadata = {}  # For verify
         self.draft_extend_cuda_graph_metadata = {}  # For draft extend
 
+    def _process_multi_item_scoring(
+        self, forward_batch: ForwardBatch
+    ) -> MultiItemScoringParams:
+        """Process multi-item scoring tensors for FlashInfer attention.
+
+        This method handles sequences containing multiple "items" separated by delimiter tokens,
+        where each item needs specific attention patterns that respect item boundaries.
+
+        The method produces four key tensors for FlashInfer:
+        - prefix_len_ptr: uint32 tensor with prefix length for each prompt in batch
+        - token_pos_in_items_ptr: uint16 tensor with token positions starting from 0 at delimiters
+        - token_pos_in_items_len: padding length for batch processing
+        - max_item_len_ptr: uint16 tensor with max item length for each prompt
+
+        Args:
+            forward_batch: The forward batch containing input sequences and delimiter info
+
+        Returns:
+            MultiItemScoringParams: The processed multi-item scoring parameters
+
+        Examples:
+            Following FlashInfer definition: for 3 items of length 3, 2, 4 respectively:
+            token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
+
+            Case 1: Single sequence
+            Text: "What is the capital of France? <delim> London <delim> Paris <delim> Berlin <delim>"
+            Tokens: [What, is, the, capital, of, France, ?, <delim>, London, <delim>, Paris, <delim>, Berlin, <delim>]
+            Indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+            - prefix_len_ptr: [7] (query length before first delimiter)
+            - token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0] (delim=0, London=1, delim=0, Paris=1, delim=0, Berlin=1, delim=0)
+            - token_pos_in_items_len: 7 (actual length)
+            - max_item_len_ptr: [1] (max item length is 1 token - all options are single tokens)
+
+            Case 2: Batch processing (batch_size=2)
+            Sequence 1: 2 items of length 2, 1 → [0, 1, 2, 0, 1, 0] (6 elements)
+            Sequence 2: 3 items of length 1, 3, 2 → [0, 1, 0, 1, 2, 3, 0, 1, 2, 0] (10 elements)
+            After padding both to length 10:
+            - token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0]
+            - token_pos_in_items_len: 10 (padded length for batch processing)
+            - max_item_len_ptr: [2, 3] (max lengths per sequence)
+        """
+
+        delimiter = self.multi_item_scoring_delimiter
+        if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE:
+            return MultiItemScoringParams()
+
+        delimiter_mask = forward_batch.input_ids == delimiter
+        prefix_cache_lens = getattr(forward_batch, "extend_prefix_lens", None)
+        extend_seq_lens = getattr(forward_batch, "extend_seq_lens", None)
+        prefix_len_ptr, token_pos_in_items_ptr = [], []
+        token_pos_in_items_len = 0
+
+        # If no extend_seq_lens, treat whole batch as one sequence
+        if extend_seq_lens is None or len(extend_seq_lens) <= 1:
+            extend_seq_lens = [forward_batch.input_ids.size(0)]
+
+        seq_start = 0
+        for i, seq_len in enumerate(extend_seq_lens):
+            seq_end = seq_start + seq_len
+            mask = delimiter_mask[seq_start:seq_end]
+            pos = forward_batch.positions[seq_start:seq_end]
+            delimiter_indices = torch.nonzero(mask, as_tuple=True)[0]
+
+            if len(delimiter_indices) > 0:
+                first_delim = delimiter_indices[0]
+                # Prefix length: store as scalar
+                prefix_len = first_delim + (
+                    prefix_cache_lens[i] if prefix_cache_lens is not None else 0
+                )
+                prefix_len_ptr.append(
+                    prefix_len.item() if torch.is_tensor(prefix_len) else prefix_len
+                )
+
+                # Compute relative positions within items after delimiters
+                diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1]
+                token_pos = (diff - pos[first_delim]).to(torch.uint16)
+                token_pos_in_items_ptr.append(token_pos)
+
+                # Update forward_batch positions in-place
+                pos[first_delim:] = diff - 1
+                forward_batch.positions[seq_start:seq_end] = pos
+
+            seq_start = seq_end
+
+        # Pad token_pos_in_items_ptr for batch processing
+        if token_pos_in_items_ptr:
+            token_pos_in_items_len = max(t.numel() for t in token_pos_in_items_ptr)
+            device = forward_batch.input_ids.device
+            token_pos_in_items_ptr = [
+                torch.cat(
+                    [
+                        t,
+                        torch.zeros(
+                            token_pos_in_items_len - t.numel(),
+                            dtype=torch.uint16,
+                            device=device,
+                        ),
+                    ]
+                )
+                for t in token_pos_in_items_ptr
+            ]
+
+        if not prefix_len_ptr or not token_pos_in_items_ptr:
+            return MultiItemScoringParams()
+
+        # Build final params
+        device = forward_batch.input_ids.device
+        return MultiItemScoringParams(
+            prefix_len_ptr=torch.tensor(
+                prefix_len_ptr, dtype=torch.uint32, device=device
+            ),
+            token_pos_in_items_ptr=torch.cat(token_pos_in_items_ptr, dim=0),
+            token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF,
+            max_item_len_ptr=torch.stack(
+                [
+                    t.to(torch.int32).max().to(torch.uint16)
+                    for t in token_pos_in_items_ptr
+                ],
+                dim=0,
+            ),
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
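
For readers tracing `_process_multi_item_scoring`, a pure-Python restatement of what it computes for one sequence without a prefix cache (the real method vectorizes this with `torch.cummax` over the delimiter mask). It reproduces Case 1 from the docstring:

```python
def multi_item_params(tokens, delim):
    """Prefix length before the first delimiter, plus delimiter-relative
    token positions (0 at each delimiter, then 1, 2, ... inside an item)."""
    first = tokens.index(delim)
    token_pos, cur = [], 0
    for t in tokens[first:]:
        cur = 0 if t == delim else cur + 1
        token_pos.append(cur)
    return first, token_pos


toks = ["What", "is", "the", "capital", "of", "France", "?",
        "<d>", "London", "<d>", "Paris", "<d>", "Berlin", "<d>"]
assert multi_item_params(toks, "<d>") == (7, [0, 1, 0, 1, 0, 1, 0])
```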
@@ -218,6 +409,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=forward_batch.spec_info,
+                fixed_split_size=self.decode_split_tile_size,
+                disable_split_kv=False,
             )
             self.forward_metadata = DecodeMetadata(self.decode_wrappers)
         elif forward_batch.forward_mode.is_draft_extend():
@@ -253,13 +446,26 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
-        if self.is_multimodal:
+        # Disable ragged wrapper and ensure prefix handling for multimodal and multi-item scoring
+        if self.is_multimodal or self.multi_item_scoring_delimiter is not None:
+            # use_ragged = False: Multi-item scoring requires the paged wrapper because:
+            # 1. Ragged wrapper doesn't support the specialized multi-item parameters
+            #    (prefix_len_ptr, token_pos_in_items_ptr, etc.)
+            # 2. Paged wrapper provides better control over attention masking needed
+            #    for respecting item boundaries in multi-item sequences
+            # 3. Custom masking logic conflicts with ragged wrapper's assumptions
             use_ragged = False
             extend_no_prefix = False
         else:
-            use_ragged = True
+            use_ragged = not self.enable_deterministic
             extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
 
+        # Process multi-item scoring in attention backend instead of ForwardBatch
+        multi_item_params = MultiItemScoringParams()
+        if self.multi_item_scoring_delimiter is not None:
+            # Use new backend-specific implementation
+            multi_item_params = self._process_multi_item_scoring(forward_batch)
+
         self.indices_updater_prefill.update(
             forward_batch.req_pool_indices,
             forward_batch.seq_lens,
@@ -270,9 +476,14 @@ class FlashInferAttnBackend(AttentionBackend):
             use_ragged=use_ragged,
             encoder_lens=forward_batch.encoder_lens,
             spec_info=None,
+            fixed_split_size=self.prefill_split_tile_size,
+            multi_item_params=multi_item_params,
         )
         self.forward_metadata = PrefillMetadata(
-            self.prefill_wrappers_paged, use_ragged, extend_no_prefix
+            self.prefill_wrappers_paged,
+            use_ragged,
+            extend_no_prefix,
+            multi_item_params,
         )
 
     def init_cuda_graph_state(
@@ -317,7 +528,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
@@ -344,6 +555,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=decode_wrappers,
                 encoder_lens=encoder_lens,
                 spec_info=spec_info,
+                fixed_split_size=None,
+                disable_split_kv=self.disable_cuda_graph_kv_split,
             )
             self.decode_cuda_graph_metadata[bs] = decode_wrappers
             self.forward_metadata = DecodeMetadata(decode_wrappers)
@@ -422,7 +635,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -434,6 +647,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=self.decode_cuda_graph_metadata[bs],
                 encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
                 spec_info=spec_info,
+                fixed_split_size=None,
+                disable_split_kv=self.disable_cuda_graph_kv_split,
             )
         elif forward_mode.is_target_verify():
             self.indices_updater_prefill.update(
@@ -499,7 +714,20 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
+                # Disable sliding window attention for multi-item scoring:
+                # - Sliding window could cut across item boundaries, breaking semantic coherence
+                # - Multi-item sequences need full attention to properly handle delimiter tokens
+                # - Specialized multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr)
+                #   provide more precise attention control than simple sliding windows
+                # - Item-aware masking takes precedence over window-based masking
+                window_left=(
+                    layer.sliding_window_size
+                    if not (
+                        self.forward_metadata.multi_item_params
+                        and self.forward_metadata.multi_item_params.is_enabled()
+                    )
+                    else -1
+                ),
                 logits_soft_cap=logits_soft_cap,
                 # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
                 k_scale=layer.k_scale_float,
@@ -507,9 +735,13 @@ class FlashInferAttnBackend(AttentionBackend):
             )
         else:
             causal = True
-            if layer.attn_type == AttentionType.ENCODER_ONLY:
-                save_kv_cache = False
+            if (
+                layer.is_cross_attention
+                or layer.attn_type == AttentionType.ENCODER_ONLY
+            ):
                 causal = False
+            if save_kv_cache and layer.attn_type == AttentionType.ENCODER_ONLY:
+                save_kv_cache = False
 
             if self.forward_metadata.extend_no_prefix:
                 # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
@@ -638,7 +870,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -651,7 +885,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -663,6 +899,8 @@ class FlashInferIndicesUpdaterDecode:
             None,
             spec_info,
             seq_lens_cpu,
+            fixed_split_size=fixed_split_size,
+            disable_split_kv=disable_split_kv,
         )
 
     def update_sliding_window(
@@ -673,7 +911,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
        encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         assert self.sliding_window_size is not None
         for wrapper_id in range(2):
@@ -721,7 +961,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -753,9 +995,11 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
@@ -799,19 +1043,51 @@ class FlashInferIndicesUpdaterDecode:
             global_override_indptr_cpu[0] = 0
             global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
 
-        wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.data_type,
-            q_data_type=self.q_data_type,
-            non_blocking=True,
+        # Check if this specific wrapper's begin_forward has been replaced with fast_decode_plan
+        # by checking if it's a partial function with fast_decode_plan as the func
+        wrapper_uses_fast_decode_plan = (
+            hasattr(wrapper.begin_forward, "func")
+            and wrapper.begin_forward.func == fast_decode_plan
         )
 
+        if wrapper_uses_fast_decode_plan:
+            # When begin_forward is replaced with fast_decode_plan, pass global_override_indptr_cpu
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
+                global_override_indptr_cpu=global_override_indptr_cpu,
+            )
+        else:
+            # When using original begin_forward, don't pass global_override_indptr_cpu
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
+            )
+
         if locally_override:
             global_override_indptr_cpu = None
 
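The `wrapper_uses_fast_decode_plan` check above works because `functools.partial` objects expose the wrapped callable as `.func`, while plain bound methods do not (they expose `__func__`). A self-contained sketch of the idiom with stand-in classes:

```python
import functools


def fast_decode_plan(wrapper, *args, **kwargs):
    """Stand-in for the fast planning path imported from flashinfer."""


class Wrapper:
    def begin_forward(self, *args, **kwargs):
        """Stand-in for the original plan entry point."""


w = Wrapper()
# Monkey-patch the instance, as the backend assumes happens elsewhere:
w.begin_forward = functools.partial(fast_decode_plan, w)

uses_fast_path = (
    hasattr(w.begin_forward, "func") and w.begin_forward.func == fast_decode_plan
)
assert uses_fast_path  # an unpatched wrapper fails the hasattr check
```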
@@ -858,7 +1134,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -873,7 +1150,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         if use_ragged:
             # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -897,6 +1176,8 @@ class FlashInferIndicesUpdaterPrefill:
             self.qo_indptr[0],
             use_ragged,
             spec_info,
+            fixed_split_size=fixed_split_size,
+            multi_item_params=multi_item_params,
         )
 
     def update_sliding_window(
@@ -909,7 +1190,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -943,6 +1226,7 @@ class FlashInferIndicesUpdaterPrefill:
                 use_ragged,
                 spec_info,
                 use_sliding_window_kv_pool=use_sliding_window_kv_pool,
+                multi_item_params=multi_item_params,
             )
 
     def update_cross_attention(
@@ -955,7 +1239,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -982,6 +1268,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.qo_indptr[wrapper_id],
                 use_ragged,
                 spec_info,
+                multi_item_params=multi_item_params,
             )
 
     def call_begin_forward(
@@ -997,8 +1284,10 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         use_sliding_window_kv_pool: bool = False,
+        fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -1024,9 +1313,7 @@ class FlashInferIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
-            assert isinstance(spec_info, EagleDraftInput) or isinstance(
-                spec_info, EagleVerifyInput
-            )
+            assert isinstance(spec_info, SpecInput)
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
@@ -1056,6 +1343,22 @@ class FlashInferIndicesUpdaterPrefill:
             )
 
         # cached part
+        # Conditionally set multi-item parameters
+        if multi_item_params is not None and multi_item_params.is_enabled():
+            # Multi-item scoring is active - use specialized parameters and disable generic custom_mask
+            use_custom_mask = None
+            prefix_len_ptr = multi_item_params.prefix_len_ptr
+            token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr
+            token_pos_in_items_len = multi_item_params.token_pos_in_items_len
+            max_item_len_ptr = multi_item_params.max_item_len_ptr
+        else:
+            # No multi-item scoring - use standard parameters
+            use_custom_mask = custom_mask
+            prefix_len_ptr = None
+            token_pos_in_items_ptr = None
+            token_pos_in_items_len = 0
+            max_item_len_ptr = None
+
         wrapper_paged.begin_forward(
             qo_indptr,
             kv_indptr,
@@ -1067,8 +1370,13 @@ class FlashInferIndicesUpdaterPrefill:
             1,
             q_data_type=self.q_data_type,
             kv_data_type=self.data_type,
-            custom_mask=custom_mask,
+            custom_mask=use_custom_mask,
             non_blocking=True,
+            fixed_split_size=fixed_split_size,
+            prefix_len_ptr=prefix_len_ptr,
+            token_pos_in_items_ptr=token_pos_in_items_ptr,
+            token_pos_in_items_len=token_pos_in_items_len,
+            max_item_len_ptr=max_item_len_ptr,
         )
 
 
@@ -1084,7 +1392,7 @@ class FlashInferMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
 
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
@@ -1104,7 +1412,7 @@ class FlashInferMultiStepDraftBackend:
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
         self.attn_backends: List[FlashInferAttnBackend] = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
            self.attn_backends.append(
                 FlashInferAttnBackend(
                     model_runner,
@@ -1148,7 +1456,7 @@ class FlashInferMultiStepDraftBackend:
         )
 
         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()
 
         # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
         indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
@@ -1192,7 +1500,7 @@ class FlashInferMultiStepDraftBackend:
             device="cuda",
         )
 
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
@@ -1276,166 +1584,3 @@ def should_use_tensor_core(
         return gqa_group_size >= 4
     else:
         return False
-
-
-# Use as a fast path to override the indptr in flashinfer's plan function
-# This is used to remove some host-to-device copy overhead.
-global_override_indptr_cpu = None
-
-
-def fast_decode_plan(
-    self,
-    indptr: torch.Tensor,
-    indices: torch.Tensor,
-    last_page_len: torch.Tensor,
-    num_qo_heads: int,
-    num_kv_heads: int,
-    head_dim: int,
-    page_size: int,
-    pos_encoding_mode: str = "NONE",
-    window_left: int = -1,
-    logits_soft_cap: Optional[float] = None,
-    q_data_type: Optional[Union[str, torch.dtype]] = None,
-    kv_data_type: Optional[Union[str, torch.dtype]] = None,
-    data_type: Optional[Union[str, torch.dtype]] = None,
-    sm_scale: Optional[float] = None,
-    rope_scale: Optional[float] = None,
-    rope_theta: Optional[float] = None,
-    non_blocking: bool = True,
-) -> None:
-    """
-    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
-    Modifications:
-    - Remove unnecessary device-to-device copy for the cuda graph buffers.
-    - Remove unnecessary host-to-device copy for the metadata buffers.
-    """
-    batch_size = len(last_page_len)
-    if logits_soft_cap is None:
-        logits_soft_cap = 0.0
-
-    # Handle data types consistently
-    if data_type is not None:
-        if q_data_type is None:
-            q_data_type = data_type
-        if kv_data_type is None:
-            kv_data_type = data_type
-    elif q_data_type is None:
-        q_data_type = "float16"
-
-    if kv_data_type is None:
-        kv_data_type = q_data_type
-
-    if self.use_tensor_cores:
-        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
-
-    if self.is_cuda_graph_enabled:
-        if batch_size != self._fixed_batch_size:
-            raise ValueError(
-                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
-                " mismatches the batch size set during initialization {}".format(
-                    batch_size, self._fixed_batch_size
-                )
-            )
-        if len(indices) > len(self._paged_kv_indices_buf):
-            raise ValueError(
-                "The size of indices should be less than or equal to the allocated buffer"
-            )
-    else:
-        self._paged_kv_indptr_buf = indptr
-        self._paged_kv_indices_buf = indices
-        self._paged_kv_last_page_len_buf = last_page_len
-        if self.use_tensor_cores:
-            self._qo_indptr_buf = qo_indptr_host.to(
-                self.device, non_blocking=non_blocking
-            )
-
-    # Create empty tensors for dtype info if needed
-    empty_q_data = torch.empty(
-        0,
-        dtype=(
-            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
-        ),
-        device=self.device,
-    )
-
-    empty_kv_cache = torch.empty(
-        0,
-        dtype=(
-            getattr(torch, kv_data_type)
-            if isinstance(kv_data_type, str)
-            else kv_data_type
-        ),
-        device=self.device,
-    )
-
-    indptr_host = (
-        global_override_indptr_cpu
-        if global_override_indptr_cpu is not None
-        else indptr.cpu()
-    )
-
-    with torch.cuda.device(self.device):
-
-        if self.use_tensor_cores:
-            # ALSO convert last_page_len to CPU
-            if page_size == 1:
-                # When page size is 1, last_page_len is always 1.
-                # Directly construct the host tensor rather than executing a device-to-host copy.
-                last_page_len_host = torch.ones(
-                    (batch_size,), dtype=torch.int32, device="cpu"
-                )
-            else:
-                last_page_len_host = last_page_len.cpu()
-
-            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
-
-            try:
-                # Make sure we pass exactly 15 arguments for tensor core version
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    qo_indptr_host,
-                    indptr_host,
-                    kv_lens_arr_host,
-                    batch_size,  # total_num_rows
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    head_dim,
-                    head_dim,
-                    False,  # causal
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in standard plan: {e}")
-        else:
-            try:
-                # Make sure we pass exactly 15 arguments for standard version
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    indptr_host,
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    window_left,
-                    logits_soft_cap,
-                    head_dim,
-                    head_dim,
-                    empty_q_data,
-                    empty_kv_cache,
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in standard plan: {e}")
-
-    self._pos_encoding_mode = pos_encoding_mode
-    self._window_left = window_left
-    self._logits_soft_cap = logits_soft_cap
-    self._sm_scale = sm_scale
-    self._rope_scale = rope_scale
-    self._rope_theta = rope_theta