sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py

@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+import logging
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -16,11 +17,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 import torch
 
 if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import logging
-
     torch._logging.set_logs(dynamo=logging.ERROR)
     torch._dynamo.config.suppress_errors = True
 
+logger = logging.getLogger(__name__)
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
@@ -28,8 +29,10 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
+    get_int_env_var,
     is_flashinfer_available,
     is_sm100_supported,
     next_power_of_2,
@@ -39,11 +42,13 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+
 if is_flashinfer_available():
     from flashinfer import (
         BatchDecodeWithPagedKVCacheWrapper,
         BatchPrefillWithPagedKVCacheWrapper,
         BatchPrefillWithRaggedKVCacheWrapper,
+        fast_decode_plan,
     )
     from flashinfer.cascade import merge_state
     from flashinfer.decode import _get_range_buf, get_seq_lens
@@ -54,6 +59,36 @@ class WrapperDispatch(Enum):
     CROSS_ATTENTION = auto()
 
 
+@dataclass
+class MultiItemScoringParams:
+    """Parameters for multi-item scoring in attention computation.
+
+    Used when processing sequences with multiple items separated by delimiters,
+    where each item needs specific attention patterns that respect item boundaries.
+
+    Attributes:
+        prefix_len_ptr: A uint32 1D tensor indicating the prefix length of each prompt.
+            The tensor size is equal to the batch size.
+        token_pos_in_items_ptr: A uint16 1D tensor indicating the token position of each item
+            starting from 0 (delimiter) for each item. For batch size > 1,
+            sequences are concatenated with zero padding to ensure same length.
+        token_pos_in_items_len: Zero padding length for token_pos_in_items_ptr to handle
+            batch_size > 1 case. Defines the padded length for each sequence.
+        max_item_len_ptr: A uint16 tensor containing the max token length of all items
+            for each prompt in the batch.
+
+    """
+
+    prefix_len_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_len: int = 0
+    max_item_len_ptr: Optional[torch.Tensor] = None
+
+    def is_enabled(self) -> bool:
+        """Check if multi-item scoring is enabled."""
+        return self.prefix_len_ptr is not None
+
+
 @dataclass
 class DecodeMetadata:
     decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
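
The dataclass above only carries the four tensors FlashInfer expects; how they are filled is shown in _process_multi_item_scoring later in this diff. Its docstring gives the canonical layout: for items of length 3, 2 and 4, token_pos_in_items is [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0] (0 on every delimiter, 1..n inside an item, trailing 0). A short Python illustration of that layout, not part of the package:

    # Illustration only: reproduce the token_pos_in_items layout documented below.
    def token_pos_layout(item_lens):
        layout = []
        for n in item_lens:
            layout.append(0)                # delimiter that opens the item
            layout.extend(range(1, n + 1))  # positions of the item's tokens
        layout.append(0)                    # trailing delimiter
        return layout

    assert token_pos_layout([3, 2, 4]) == [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
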
@@ -64,6 +99,7 @@ class PrefillMetadata:
     prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
     use_ragged: bool
     extend_no_prefix: bool
+    multi_item_params: Optional[MultiItemScoringParams] = None
 
 
 # Reuse this workspace buffer across all flashinfer wrappers
@@ -86,6 +122,11 @@ class FlashInferAttnBackend(AttentionBackend):
     ):
         super().__init__()
 
+        # Store multi-item scoring delimiter for efficient access
+        self.multi_item_scoring_delimiter = (
+            model_runner.server_args.multi_item_scoring_delimiter
+        )
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -122,12 +163,33 @@ class FlashInferAttnBackend(AttentionBackend):
         ):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
+        # When deterministic inference is enabled, tensor cores should be used for decode
+        # Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph
+        # More information can be found here: https://github.com/flashinfer-ai/flashinfer/pull/1675
+        self.enable_deterministic = (
+            model_runner.server_args.enable_deterministic_inference
+        )
+        self.prefill_split_tile_size = None
+        self.decode_split_tile_size = None
+        self.disable_cuda_graph_kv_split = False
+        if self.enable_deterministic:
+            self.decode_use_tensor_cores = True
+            self.prefill_split_tile_size = get_int_env_var(
+                "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
+            )
+            self.decode_split_tile_size = get_int_env_var(
+                "SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
+            )
+            self.disable_cuda_graph_kv_split = True
+            global_config.flashinfer_workspace_size = 2048 * 1024 * 1024
+
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
+            global_workspace_size = global_config.flashinfer_workspace_size
             global_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
+                global_workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
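
The deterministic-inference branch above reads its split tile sizes through get_int_env_var, newly imported from sglang.srt.utils in this diff, with defaults of 4096 (prefill) and 2048 (decode). A rough stand-alone equivalent of that helper, shown only to make the fallback behavior concrete (the packaged implementation may differ in details):

    import os

    def get_int_env_var(name: str, default: int) -> int:
        # Use the environment value when set and non-empty, otherwise the default.
        value = os.environ.get(name)
        if value is None or value.strip() == "":
            return default
        return int(value)
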
@@ -204,10 +266,133 @@ class FlashInferAttnBackend(AttentionBackend):
 
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
+
         self.decode_cuda_graph_metadata = {}
         self.prefill_cuda_graph_metadata = {}  # For verify
         self.draft_extend_cuda_graph_metadata = {}  # For draft extend
 
+    def _process_multi_item_scoring(
+        self, forward_batch: ForwardBatch
+    ) -> MultiItemScoringParams:
+        """Process multi-item scoring tensors for FlashInfer attention.
+
+        This method handles sequences containing multiple "items" separated by delimiter tokens,
+        where each item needs specific attention patterns that respect item boundaries.
+
+        The method produces four key tensors for FlashInfer:
+        - prefix_len_ptr: uint32 tensor with prefix length for each prompt in batch
+        - token_pos_in_items_ptr: uint16 tensor with token positions starting from 0 at delimiters
+        - token_pos_in_items_len: padding length for batch processing
+        - max_item_len_ptr: uint16 tensor with max item length for each prompt
+
+        Args:
+            forward_batch: The forward batch containing input sequences and delimiter info
+
+        Returns:
+            MultiItemScoringParams: The processed multi-item scoring parameters
+
+        Examples:
+            Following FlashInfer definition: for 3 items of length 3, 2, 4 respectively:
+            token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
+
+            Case 1: Single sequence
+            Text: "What is the capital of France? <delim> London <delim> Paris <delim> Berlin <delim>"
+            Tokens: [What, is, the, capital, of, France, ?, <delim>, London, <delim>, Paris, <delim>, Berlin, <delim>]
+            Indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+            - prefix_len_ptr: [7] (query length before first delimiter)
+            - token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0] (delim=0, London=1, delim=0, Paris=1, delim=0, Berlin=1, delim=0)
+            - token_pos_in_items_len: 7 (actual length)
+            - max_item_len_ptr: [1] (max item length is 1 token - all options are single tokens)
+
+            Case 2: Batch processing (batch_size=2)
+            Sequence 1: 2 items of length 2, 1 → [0, 1, 2, 0, 1, 0] (6 elements)
+            Sequence 2: 3 items of length 1, 3, 2 → [0, 1, 0, 1, 2, 3, 0, 1, 2, 0] (10 elements)
+            After padding both to length 10:
+            - token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0]
+            - token_pos_in_items_len: 10 (padded length for batch processing)
+            - max_item_len_ptr: [2, 3] (max lengths per sequence)
+        """
+
+        delimiter = self.multi_item_scoring_delimiter
+        if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE:
+            return MultiItemScoringParams()
+
+        delimiter_mask = forward_batch.input_ids == delimiter
+        prefix_cache_lens = getattr(forward_batch, "extend_prefix_lens", None)
+        extend_seq_lens = getattr(forward_batch, "extend_seq_lens", None)
+        prefix_len_ptr, token_pos_in_items_ptr = [], []
+        token_pos_in_items_len = 0
+
+        # If no extend_seq_lens, treat whole batch as one sequence
+        if extend_seq_lens is None or len(extend_seq_lens) <= 1:
+            extend_seq_lens = [forward_batch.input_ids.size(0)]
+
+        seq_start = 0
+        for i, seq_len in enumerate(extend_seq_lens):
+            seq_end = seq_start + seq_len
+            mask = delimiter_mask[seq_start:seq_end]
+            pos = forward_batch.positions[seq_start:seq_end]
+            delimiter_indices = torch.nonzero(mask, as_tuple=True)[0]
+
+            if len(delimiter_indices) > 0:
+                first_delim = delimiter_indices[0]
+                # Prefix length: store as scalar
+                prefix_len = first_delim + (
+                    prefix_cache_lens[i] if prefix_cache_lens is not None else 0
+                )
+                prefix_len_ptr.append(
+                    prefix_len.item() if torch.is_tensor(prefix_len) else prefix_len
+                )
+
+                # Compute relative positions within items after delimiters
+                diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1]
+                token_pos = (diff - pos[first_delim]).to(torch.uint16)
+                token_pos_in_items_ptr.append(token_pos)
+
+                # Update forward_batch positions in-place
+                pos[first_delim:] = diff - 1
+                forward_batch.positions[seq_start:seq_end] = pos
+
+            seq_start = seq_end
+
+        # Pad token_pos_in_items_ptr for batch processing
+        if token_pos_in_items_ptr:
+            token_pos_in_items_len = max(t.numel() for t in token_pos_in_items_ptr)
+            device = forward_batch.input_ids.device
+            token_pos_in_items_ptr = [
+                torch.cat(
+                    [
+                        t,
+                        torch.zeros(
+                            token_pos_in_items_len - t.numel(),
+                            dtype=torch.uint16,
+                            device=device,
+                        ),
+                    ]
+                )
+                for t in token_pos_in_items_ptr
+            ]
+
+        if not prefix_len_ptr or not token_pos_in_items_ptr:
+            return MultiItemScoringParams()
+
+        # Build final params
+        device = forward_batch.input_ids.device
+        return MultiItemScoringParams(
+            prefix_len_ptr=torch.tensor(
+                prefix_len_ptr, dtype=torch.uint32, device=device
+            ),
+            token_pos_in_items_ptr=torch.cat(token_pos_in_items_ptr, dim=0),
+            token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF,
+            max_item_len_ptr=torch.stack(
+                [
+                    t.to(torch.int32).max().to(torch.uint16)
+                    for t in token_pos_in_items_ptr
+                ],
+                dim=0,
+            ),
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
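
The core of the method above is the position bookkeeping: from the first delimiter onward, each token's position is rewritten to its offset from the most recent delimiter (the packaged code does this with torch.cummax over the boolean delimiter mask). A plain-Python illustration of the same computation on the docstring's Case 1 (delimiters at token indices 7, 9, 11, 13 in a 14-token segment, no cached prefix); not part of the package:

    def multi_item_positions(num_tokens, delim_positions):
        first = min(delim_positions)
        prefix_len = first                    # query tokens before the first delimiter
        token_pos, last_delim = [], None
        for i in range(first, num_tokens):
            if i in delim_positions:
                last_delim = i
            token_pos.append(i - last_delim)  # 0 on a delimiter, 1.. inside an item
        return prefix_len, token_pos

    prefix_len, token_pos = multi_item_positions(14, {7, 9, 11, 13})
    assert prefix_len == 7
    assert token_pos == [0, 1, 0, 1, 0, 1, 0]
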
@@ -218,6 +403,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=forward_batch.spec_info,
+                fixed_split_size=self.decode_split_tile_size,
+                disable_split_kv=False,
             )
             self.forward_metadata = DecodeMetadata(self.decode_wrappers)
         elif forward_batch.forward_mode.is_draft_extend():
@@ -253,13 +440,26 @@ class FlashInferAttnBackend(AttentionBackend):
             else:
                 prefix_lens = forward_batch.extend_prefix_lens
 
-            if self.is_multimodal:
+            # Disable ragged wrapper and ensure prefix handling for multimodal and multi-item scoring
+            if self.is_multimodal or self.multi_item_scoring_delimiter is not None:
+                # use_ragged = False: Multi-item scoring requires the paged wrapper because:
+                # 1. Ragged wrapper doesn't support the specialized multi-item parameters
+                #    (prefix_len_ptr, token_pos_in_items_ptr, etc.)
+                # 2. Paged wrapper provides better control over attention masking needed
+                #    for respecting item boundaries in multi-item sequences
+                # 3. Custom masking logic conflicts with ragged wrapper's assumptions
                 use_ragged = False
                 extend_no_prefix = False
             else:
-                use_ragged = True
+                use_ragged = not self.enable_deterministic
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
 
+            # Process multi-item scoring in attention backend instead of ForwardBatch
+            multi_item_params = MultiItemScoringParams()
+            if self.multi_item_scoring_delimiter is not None:
+                # Use new backend-specific implementation
+                multi_item_params = self._process_multi_item_scoring(forward_batch)
+
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
@@ -270,9 +470,14 @@ class FlashInferAttnBackend(AttentionBackend):
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=None,
+                fixed_split_size=self.prefill_split_tile_size,
+                multi_item_params=multi_item_params,
             )
             self.forward_metadata = PrefillMetadata(
-                self.prefill_wrappers_paged, use_ragged, extend_no_prefix
+                self.prefill_wrappers_paged,
+                use_ragged,
+                extend_no_prefix,
+                multi_item_params,
             )
 
     def init_cuda_graph_state(
@@ -317,7 +522,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
    ):
        if forward_mode.is_decode_or_idle():
            decode_wrappers = []
@@ -344,6 +549,8 @@ class FlashInferAttnBackend(AttentionBackend):
                decode_wrappers=decode_wrappers,
                encoder_lens=encoder_lens,
                spec_info=spec_info,
+                fixed_split_size=None,
+                disable_split_kv=self.disable_cuda_graph_kv_split,
            )
            self.decode_cuda_graph_metadata[bs] = decode_wrappers
            self.forward_metadata = DecodeMetadata(decode_wrappers)
@@ -422,7 +629,7 @@ class FlashInferAttnBackend(AttentionBackend):
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        if forward_mode.is_decode_or_idle():
@@ -434,6 +641,8 @@ class FlashInferAttnBackend(AttentionBackend):
                decode_wrappers=self.decode_cuda_graph_metadata[bs],
                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
                spec_info=spec_info,
+                fixed_split_size=None,
+                disable_split_kv=self.disable_cuda_graph_kv_split,
            )
        elif forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
@@ -499,10 +708,24 @@ class FlashInferAttnBackend(AttentionBackend):
                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                causal=not layer.is_cross_attention,
                sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
+                # Disable sliding window attention for multi-item scoring:
+                # - Sliding window could cut across item boundaries, breaking semantic coherence
+                # - Multi-item sequences need full attention to properly handle delimiter tokens
+                # - Specialized multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr)
+                #   provide more precise attention control than simple sliding windows
+                # - Item-aware masking takes precedence over window-based masking
+                window_left=(
+                    layer.sliding_window_size
+                    if not (
+                        self.forward_metadata.multi_item_params
+                        and self.forward_metadata.multi_item_params.is_enabled()
+                    )
+                    else -1
+                ),
                logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
+                # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
+                k_scale=layer.k_scale_float,
+                v_scale=layer.v_scale_float,
            )
        else:
            causal = True
@@ -580,8 +803,9 @@ class FlashInferAttnBackend(AttentionBackend):
            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
            sm_scale=layer.scaling,
            logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
+            # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
+            k_scale=layer.k_scale_float,
+            v_scale=layer.v_scale_float,
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
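
The k_scale/v_scale change in the two hunks above swaps 0-dim GPU tensors for plain Python floats. Per the in-line comment, the motivation is CUDA graph capture: extracting a scalar from a device tensor is a device-to-host copy, which is not allowed while a graph is being captured, so the conversion has to happen once ahead of time. A minimal sketch of that pattern with assumed attribute names on the layer side (only k_scale_float/v_scale_float appear in this diff):

    from typing import Optional

    import torch

    class AttnLayerScales:
        # Sketch only: precompute host-side floats at load time, outside any
        # torch.cuda.graph(...) capture region, so the attention call can pass
        # them later without triggering a device-to-host synchronization.
        def __init__(self, k_scale: Optional[torch.Tensor], v_scale: Optional[torch.Tensor]):
            self.k_scale = k_scale
            self.v_scale = v_scale
            self.k_scale_float = float(k_scale.item()) if k_scale is not None else None
            self.v_scale_float = float(v_scale.item()) if v_scale is not None else None
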
@@ -636,7 +860,9 @@ class FlashInferIndicesUpdaterDecode:
        seq_lens_sum: int,
        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
        encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
    ):
        # Keep the signature for type checking. It will be assigned during runtime.
        raise NotImplementedError()
@@ -649,7 +875,9 @@ class FlashInferIndicesUpdaterDecode:
        seq_lens_sum: int,
        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
        encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
    ):
        decode_wrappers = decode_wrappers or self.decode_wrappers
        self.call_begin_forward(
@@ -661,6 +889,8 @@ class FlashInferIndicesUpdaterDecode:
            None,
            spec_info,
            seq_lens_cpu,
+            fixed_split_size=fixed_split_size,
+            disable_split_kv=disable_split_kv,
        )

    def update_sliding_window(
@@ -671,7 +901,9 @@ class FlashInferIndicesUpdaterDecode:
        seq_lens_sum: int,
        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
        encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
    ):
        assert self.sliding_window_size is not None
        for wrapper_id in range(2):
@@ -719,7 +951,9 @@ class FlashInferIndicesUpdaterDecode:
        seq_lens_sum: int,
        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
        encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
    ):
        for wrapper_id in range(2):
            if wrapper_id == 0:
@@ -751,9 +985,11 @@ class FlashInferIndicesUpdaterDecode:
        paged_kernel_lens_sum: int,
        kv_indptr: torch.Tensor,
        kv_start_idx: torch.Tensor,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
        seq_lens_cpu: Optional[torch.Tensor],
        use_sliding_window_kv_pool: bool = False,
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
    ):
        if spec_info is None:
            bs = len(req_pool_indices)
@@ -797,19 +1033,51 @@ class FlashInferIndicesUpdaterDecode:
            global_override_indptr_cpu[0] = 0
            global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)

-        wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.data_type,
-            q_data_type=self.q_data_type,
-            non_blocking=True,
+        # Check if this specific wrapper's begin_forward has been replaced with fast_decode_plan
+        # by checking if it's a partial function with fast_decode_plan as the func
+        wrapper_uses_fast_decode_plan = (
+            hasattr(wrapper.begin_forward, "func")
+            and wrapper.begin_forward.func == fast_decode_plan
        )

+        if wrapper_uses_fast_decode_plan:
+            # When begin_forward is replaced with fast_decode_plan, pass global_override_indptr_cpu
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
+                global_override_indptr_cpu=global_override_indptr_cpu,
+            )
+        else:
+            # When using original begin_forward, don't pass global_override_indptr_cpu
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
+            )
+
        if locally_override:
            global_override_indptr_cpu = None

@@ -856,7 +1124,8 @@ class FlashInferIndicesUpdaterPrefill:
        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
        encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
    ):
        # Keep the signature for type checking. It will be assigned during runtime.
        raise NotImplementedError()
@@ -871,7 +1140,9 @@ class FlashInferIndicesUpdaterPrefill:
        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
        encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
    ):
        if use_ragged:
            # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -895,6 +1166,8 @@ class FlashInferIndicesUpdaterPrefill:
            self.qo_indptr[0],
            use_ragged,
            spec_info,
+            fixed_split_size=fixed_split_size,
+            multi_item_params=multi_item_params,
        )

    def update_sliding_window(
@@ -907,7 +1180,9 @@ class FlashInferIndicesUpdaterPrefill:
        seq_lens_sum: int,
        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
        encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
    ):
        for wrapper_id in range(2):
            if wrapper_id == 0:
@@ -941,6 +1216,7 @@ class FlashInferIndicesUpdaterPrefill:
            use_ragged,
            spec_info,
            use_sliding_window_kv_pool=use_sliding_window_kv_pool,
+            multi_item_params=multi_item_params,
        )

    def update_cross_attention(
@@ -953,7 +1229,9 @@ class FlashInferIndicesUpdaterPrefill:
        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
        encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
    ):
        for wrapper_id in range(2):
            if wrapper_id == 0:
@@ -980,6 +1258,7 @@ class FlashInferIndicesUpdaterPrefill:
            self.qo_indptr[wrapper_id],
            use_ragged,
            spec_info,
+            multi_item_params=multi_item_params,
        )

    def call_begin_forward(
@@ -995,8 +1274,10 @@ class FlashInferIndicesUpdaterPrefill:
        kv_indptr: torch.Tensor,
        qo_indptr: torch.Tensor,
        use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
        use_sliding_window_kv_pool: bool = False,
+        fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
    ):
        bs = len(seq_lens)
        if spec_info is None:
@@ -1022,9 +1303,7 @@ class FlashInferIndicesUpdaterPrefill:
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
        else:
-            assert isinstance(spec_info, EagleDraftInput) or isinstance(
-                spec_info, EagleVerifyInput
-            )
+            assert isinstance(spec_info, SpecInput)
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    req_pool_indices,
@@ -1054,6 +1333,22 @@ class FlashInferIndicesUpdaterPrefill:
            )

        # cached part
+        # Conditionally set multi-item parameters
+        if multi_item_params is not None and multi_item_params.is_enabled():
+            # Multi-item scoring is active - use specialized parameters and disable generic custom_mask
+            use_custom_mask = None
+            prefix_len_ptr = multi_item_params.prefix_len_ptr
+            token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr
+            token_pos_in_items_len = multi_item_params.token_pos_in_items_len
+            max_item_len_ptr = multi_item_params.max_item_len_ptr
+        else:
+            # No multi-item scoring - use standard parameters
+            use_custom_mask = custom_mask
+            prefix_len_ptr = None
+            token_pos_in_items_ptr = None
+            token_pos_in_items_len = 0
+            max_item_len_ptr = None
+
        wrapper_paged.begin_forward(
            qo_indptr,
            kv_indptr,
@@ -1065,8 +1360,13 @@ class FlashInferIndicesUpdaterPrefill:
            1,
            q_data_type=self.q_data_type,
            kv_data_type=self.data_type,
-            custom_mask=custom_mask,
+            custom_mask=use_custom_mask,
            non_blocking=True,
+            fixed_split_size=fixed_split_size,
+            prefix_len_ptr=prefix_len_ptr,
+            token_pos_in_items_ptr=token_pos_in_items_ptr,
+            token_pos_in_items_len=token_pos_in_items_len,
+            max_item_len_ptr=max_item_len_ptr,
        )


@@ -1082,7 +1382,7 @@ class FlashInferMultiStepDraftBackend:
        topk: int,
        speculative_num_steps: int,
    ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
@@ -1146,7 +1446,7 @@ class FlashInferMultiStepDraftBackend:
        )

        assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()

        # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
        indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
@@ -1274,166 +1574,3 @@ def should_use_tensor_core(
        return gqa_group_size >= 4
    else:
        return False
-
-
-# Use as a fast path to override the indptr in flashinfer's plan function
-# This is used to remove some host-to-device copy overhead.
-global_override_indptr_cpu = None
-
-
-def fast_decode_plan(
-    self,
-    indptr: torch.Tensor,
-    indices: torch.Tensor,
-    last_page_len: torch.Tensor,
-    num_qo_heads: int,
-    num_kv_heads: int,
-    head_dim: int,
-    page_size: int,
-    pos_encoding_mode: str = "NONE",
-    window_left: int = -1,
-    logits_soft_cap: Optional[float] = None,
-    q_data_type: Optional[Union[str, torch.dtype]] = None,
-    kv_data_type: Optional[Union[str, torch.dtype]] = None,
-    data_type: Optional[Union[str, torch.dtype]] = None,
-    sm_scale: Optional[float] = None,
-    rope_scale: Optional[float] = None,
-    rope_theta: Optional[float] = None,
-    non_blocking: bool = True,
-) -> None:
-    """
-    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
-    Modifications:
-    - Remove unnecessary device-to-device copy for the cuda graph buffers.
-    - Remove unnecessary host-to-device copy for the metadata buffers.
-    """
-    batch_size = len(last_page_len)
-    if logits_soft_cap is None:
-        logits_soft_cap = 0.0
-
-    # Handle data types consistently
-    if data_type is not None:
-        if q_data_type is None:
-            q_data_type = data_type
-        if kv_data_type is None:
-            kv_data_type = data_type
-    elif q_data_type is None:
-        q_data_type = "float16"
-
-    if kv_data_type is None:
-        kv_data_type = q_data_type
-
-    if self.use_tensor_cores:
-        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
-
-    if self.is_cuda_graph_enabled:
-        if batch_size != self._fixed_batch_size:
-            raise ValueError(
-                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
-                " mismatches the batch size set during initialization {}".format(
-                    batch_size, self._fixed_batch_size
-                )
-            )
-        if len(indices) > len(self._paged_kv_indices_buf):
-            raise ValueError(
-                "The size of indices should be less than or equal to the allocated buffer"
-            )
-    else:
-        self._paged_kv_indptr_buf = indptr
-        self._paged_kv_indices_buf = indices
-        self._paged_kv_last_page_len_buf = last_page_len
-        if self.use_tensor_cores:
-            self._qo_indptr_buf = qo_indptr_host.to(
-                self.device, non_blocking=non_blocking
-            )
-
-    # Create empty tensors for dtype info if needed
-    empty_q_data = torch.empty(
-        0,
-        dtype=(
-            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
-        ),
-        device=self.device,
-    )
-
-    empty_kv_cache = torch.empty(
-        0,
-        dtype=(
-            getattr(torch, kv_data_type)
-            if isinstance(kv_data_type, str)
-            else kv_data_type
-        ),
-        device=self.device,
-    )
-
-    indptr_host = (
-        global_override_indptr_cpu
-        if global_override_indptr_cpu is not None
-        else indptr.cpu()
-    )
-
-    with torch.cuda.device(self.device):
-
-        if self.use_tensor_cores:
-            # ALSO convert last_page_len to CPU
-            if page_size == 1:
-                # When page size is 1, last_page_len is always 1.
-                # Directly construct the host tensor rather than executing a device-to-host copy.
-                last_page_len_host = torch.ones(
-                    (batch_size,), dtype=torch.int32, device="cpu"
-                )
-            else:
-                last_page_len_host = last_page_len.cpu()
-
-            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
-
-            try:
-                # Make sure we pass exactly 15 arguments for tensor core version
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    qo_indptr_host,
-                    indptr_host,
-                    kv_lens_arr_host,
-                    batch_size,  # total_num_rows
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    head_dim,
-                    head_dim,
-                    False,  # causal
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in standard plan: {e}")
-        else:
-            try:
-                # Make sure we pass exactly 15 arguments for standard version
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    indptr_host,
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    window_left,
-                    logits_soft_cap,
-                    head_dim,
-                    head_dim,
-                    empty_q_data,
-                    empty_kv_cache,
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in standard plan: {e}")
-
-    self._pos_encoding_mode = pos_encoding_mode
-    self._window_left = window_left
-    self._logits_soft_cap = logits_soft_cap
-    self._sm_scale = sm_scale
-    self._rope_scale = rope_scale
-    self._rope_theta = rope_theta
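
The final hunk removes the module-local fast_decode_plan; the name is now imported from flashinfer itself (see the import hunk near the top of this file's diff). The "-797,19" hunk earlier detects whether a wrapper's begin_forward has been swapped for that fast path by checking for a functools.partial whose .func is fast_decode_plan, as its comment states. A self-contained sketch of that detection pattern (the class and function below are stand-ins, not the sglang or flashinfer objects):

    import functools

    def fast_decode_plan(wrapper, *args, **kwargs):  # stand-in for the imported helper
        """Pretend fast planning path."""

    class DecodeWrapper:  # stand-in for BatchDecodeWithPagedKVCacheWrapper
        def begin_forward(self, *args, **kwargs):
            """Original begin_forward/plan."""

    w = DecodeWrapper()
    w.begin_forward = functools.partial(fast_decode_plan, w)  # install the fast path

    uses_fast_plan = (
        hasattr(w.begin_forward, "func") and w.begin_forward.func == fast_decode_plan
    )
    assert uses_fast_plan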