sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+import logging
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -15,20 +16,13 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
 import torch
 
-if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import logging
-
-    torch._logging.set_logs(dynamo=logging.ERROR)
-    torch._dynamo.config.suppress_errors = True
-
-from sglang.global_config import global_config
+from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     get_int_env_var,
@@ -41,6 +35,12 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+logger = logging.getLogger(__name__)
+
+if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
+
 
 if is_flashinfer_available():
     from flashinfer import (
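Note on the import-time change above: the old gate read `os.environ["SGLANG_ENABLE_TORCH_COMPILE"]` directly, which raises `KeyError` whenever the variable is not exported, while the new code goes through the `envs` registry in `sglang.srt.environ`. A minimal sketch of a safe plain-`os.environ` equivalent (illustrative only, not code from this diff):

    import logging
    import os

    import torch

    # Old pattern: os.environ["SGLANG_ENABLE_TORCH_COMPILE"] fails with KeyError when unset.
    # Safe equivalent: fall back to "0" when the variable is absent.
    if os.environ.get("SGLANG_ENABLE_TORCH_COMPILE", "0") == "1":
        torch._logging.set_logs(dynamo=logging.ERROR)
        torch._dynamo.config.suppress_errors = True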
@@ -50,7 +50,6 @@ if is_flashinfer_available():
         fast_decode_plan,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import _get_range_buf, get_seq_lens
 
 
 class WrapperDispatch(Enum):
@@ -58,6 +57,36 @@ class WrapperDispatch(Enum):
     CROSS_ATTENTION = auto()
 
 
+@dataclass
+class MultiItemScoringParams:
+    """Parameters for multi-item scoring in attention computation.
+
+    Used when processing sequences with multiple items separated by delimiters,
+    where each item needs specific attention patterns that respect item boundaries.
+
+    Attributes:
+        prefix_len_ptr: A uint32 1D tensor indicating the prefix length of each prompt.
+            The tensor size is equal to the batch size.
+        token_pos_in_items_ptr: A uint16 1D tensor indicating the token position of each item
+            starting from 0 (delimiter) for each item. For batch size > 1,
+            sequences are concatenated with zero padding to ensure same length.
+        token_pos_in_items_len: Zero padding length for token_pos_in_items_ptr to handle
+            batch_size > 1 case. Defines the padded length for each sequence.
+        max_item_len_ptr: A uint16 tensor containing the max token length of all items
+            for each prompt in the batch.
+
+    """
+
+    prefix_len_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_len: int = 0
+    max_item_len_ptr: Optional[torch.Tensor] = None
+
+    def is_enabled(self) -> bool:
+        """Check if multi-item scoring is enabled."""
+        return self.prefix_len_ptr is not None
+
+
 @dataclass
 class DecodeMetadata:
     decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
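For reference, the layout carried by MultiItemScoringParams follows the FlashInfer multi-item convention that the `_process_multi_item_scoring` docstring spells out further down in this diff. A small hand-built sketch for a single prompt with a hypothetical 7-token prefix followed by three items of lengths 3, 2 and 4 (the prefix length is an assumption for illustration):

    import torch

    # token_pos_in_items_ptr restarts at 0 on each delimiter, then counts 1..len
    # inside the item; the trailing 0 is the final delimiter.
    params = MultiItemScoringParams(
        prefix_len_ptr=torch.tensor([7], dtype=torch.uint32),
        token_pos_in_items_ptr=torch.tensor(
            [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0], dtype=torch.uint16
        ),
        token_pos_in_items_len=13,  # no padding needed for a single sequence
        max_item_len_ptr=torch.tensor([4], dtype=torch.uint16),
    )
    assert params.is_enabled()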
@@ -68,6 +97,7 @@ class PrefillMetadata:
     prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
     use_ragged: bool
     extend_no_prefix: bool
+    multi_item_params: Optional[MultiItemScoringParams] = None
 
 
 # Reuse this workspace buffer across all flashinfer wrappers
@@ -87,9 +117,15 @@ class FlashInferAttnBackend(AttentionBackend):
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
         kv_last_page_len_buf: Optional[torch.Tensor] = None,
+        init_new_workspace: bool = False,
     ):
         super().__init__()
 
+        # Store multi-item scoring delimiter for efficient access
+        self.multi_item_scoring_delimiter = (
+            model_runner.server_args.multi_item_scoring_delimiter
+        )
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -124,7 +160,7 @@ class FlashInferAttnBackend(AttentionBackend):
             or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
             or "MiMoForCausalLM" in model_runner.model_config.hf_config.architectures
         ):
-            global_config.flashinfer_workspace_size = 512 * 1024 * 1024
+            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(512 * 1024 * 1024)
 
         # When deterministic inference is enabled, tensor cores should be used for decode
         # Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph
@@ -144,19 +180,26 @@ class FlashInferAttnBackend(AttentionBackend):
                 "SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
             )
             self.disable_cuda_graph_kv_split = True
-            global_config.flashinfer_workspace_size = 2048 * 1024 * 1024
+            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(2048 * 1024 * 1024)
 
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
-            global_workspace_size = global_config.flashinfer_workspace_size
+            global_workspace_size = envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get()
             global_workspace_buffer = torch.empty(
                 global_workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
-        self.workspace_buffer = global_workspace_buffer
+        if init_new_workspace:
+            self.workspace_buffer = torch.empty(
+                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        else:
+            self.workspace_buffer = global_workspace_buffer
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [
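The workspace handling above amounts to: the size knob now lives in the SGLANG_FLASHINFER_WORKSPACE_SIZE env binding rather than global_config, the first backend allocates a process-wide buffer, and a backend constructed with init_new_workspace=True gets a private buffer of the same size. A rough standalone paraphrase (the helper function below is mine, not part of the diff):

    import torch

    _global_workspace = None

    def get_workspace(size_bytes: int, device: str, private: bool = False) -> torch.Tensor:
        """Return a FlashInfer workspace buffer: shared by default, private on request."""
        global _global_workspace
        if _global_workspace is None:
            _global_workspace = torch.empty(size_bytes, dtype=torch.uint8, device=device)
        if private:
            # Mirrors init_new_workspace=True: do not share scratch space with other wrappers.
            return torch.empty(size_bytes, dtype=torch.uint8, device=device)
        return _global_workspace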
@@ -187,7 +230,16 @@ class FlashInferAttnBackend(AttentionBackend):
 
         fmha_backend = "auto"
         if is_sm100_supported():
-            fmha_backend = "cutlass"
+            # Disable CUTLASS backend when piecewise cuda graph is enabled
+            # due to TMA descriptor initialization issues on B200
+            if model_runner.server_args.enable_piecewise_cuda_graph:
+                logger.warning(
+                    "CUTLASS backend is disabled when piecewise cuda graph is enabled "
+                    "due to TMA descriptor initialization issues on B200. "
+                    "Using auto backend instead for stability."
+                )
+            else:
+                fmha_backend = "cutlass"
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
             self.workspace_buffer, "NHD", backend=fmha_backend
         )
@@ -229,10 +281,133 @@ class FlashInferAttnBackend(AttentionBackend):
 
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
+
         self.decode_cuda_graph_metadata = {}
         self.prefill_cuda_graph_metadata = {}  # For verify
         self.draft_extend_cuda_graph_metadata = {}  # For draft extend
 
+    def _process_multi_item_scoring(
+        self, forward_batch: ForwardBatch
+    ) -> MultiItemScoringParams:
+        """Process multi-item scoring tensors for FlashInfer attention.
+
+        This method handles sequences containing multiple "items" separated by delimiter tokens,
+        where each item needs specific attention patterns that respect item boundaries.
+
+        The method produces four key tensors for FlashInfer:
+        - prefix_len_ptr: uint32 tensor with prefix length for each prompt in batch
+        - token_pos_in_items_ptr: uint16 tensor with token positions starting from 0 at delimiters
+        - token_pos_in_items_len: padding length for batch processing
+        - max_item_len_ptr: uint16 tensor with max item length for each prompt
+
+        Args:
+            forward_batch: The forward batch containing input sequences and delimiter info
+
+        Returns:
+            MultiItemScoringParams: The processed multi-item scoring parameters
+
+        Examples:
+            Following FlashInfer definition: for 3 items of length 3, 2, 4 respectively:
+            token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
+
+            Case 1: Single sequence
+            Text: "What is the capital of France? <delim> London <delim> Paris <delim> Berlin <delim>"
+            Tokens:  [What, is, the, capital, of, France, ?, <delim>, London, <delim>, Paris, <delim>, Berlin, <delim>]
+            Indices: [   0,  1,   2,       3,  4,      5, 6,       7,      8,       9,    10,      11,     12,      13]
+            - prefix_len_ptr: [7] (query length before first delimiter)
+            - token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0] (delim=0, London=1, delim=0, Paris=1, delim=0, Berlin=1, delim=0)
+            - token_pos_in_items_len: 7 (actual length)
+            - max_item_len_ptr: [1] (max item length is 1 token - all options are single tokens)
+
+            Case 2: Batch processing (batch_size=2)
+            Sequence 1: 2 items of length 2, 1 → [0, 1, 2, 0, 1, 0] (6 elements)
+            Sequence 2: 3 items of length 1, 3, 2 → [0, 1, 0, 1, 2, 3, 0, 1, 2, 0] (10 elements)
+            After padding both to length 10:
+            - token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0]
+            - token_pos_in_items_len: 10 (padded length for batch processing)
+            - max_item_len_ptr: [2, 3] (max lengths per sequence)
+        """
+
+        delimiter = self.multi_item_scoring_delimiter
+        if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE:
+            return MultiItemScoringParams()
+
+        delimiter_mask = forward_batch.input_ids == delimiter
+        prefix_cache_lens = getattr(forward_batch, "extend_prefix_lens", None)
+        extend_seq_lens = getattr(forward_batch, "extend_seq_lens", None)
+        prefix_len_ptr, token_pos_in_items_ptr = [], []
+        token_pos_in_items_len = 0
+
+        # If no extend_seq_lens, treat whole batch as one sequence
+        if extend_seq_lens is None or len(extend_seq_lens) <= 1:
+            extend_seq_lens = [forward_batch.input_ids.size(0)]
+
+        seq_start = 0
+        for i, seq_len in enumerate(extend_seq_lens):
+            seq_end = seq_start + seq_len
+            mask = delimiter_mask[seq_start:seq_end]
+            pos = forward_batch.positions[seq_start:seq_end]
+            delimiter_indices = torch.nonzero(mask, as_tuple=True)[0]
+
+            if len(delimiter_indices) > 0:
+                first_delim = delimiter_indices[0]
+                # Prefix length: store as scalar
+                prefix_len = first_delim + (
+                    prefix_cache_lens[i] if prefix_cache_lens is not None else 0
+                )
+                prefix_len_ptr.append(
+                    prefix_len.item() if torch.is_tensor(prefix_len) else prefix_len
+                )
+
+                # Compute relative positions within items after delimiters
+                diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1]
+                token_pos = (diff - pos[first_delim]).to(torch.uint16)
+                token_pos_in_items_ptr.append(token_pos)
+
+                # Update forward_batch positions in-place
+                pos[first_delim:] = diff - 1
+                forward_batch.positions[seq_start:seq_end] = pos
+
+            seq_start = seq_end
+
+        # Pad token_pos_in_items_ptr for batch processing
+        if token_pos_in_items_ptr:
+            token_pos_in_items_len = max(t.numel() for t in token_pos_in_items_ptr)
+            device = forward_batch.input_ids.device
+            token_pos_in_items_ptr = [
+                torch.cat(
+                    [
+                        t,
+                        torch.zeros(
+                            token_pos_in_items_len - t.numel(),
+                            dtype=torch.uint16,
+                            device=device,
+                        ),
+                    ]
+                )
+                for t in token_pos_in_items_ptr
+            ]
+
+        if not prefix_len_ptr or not token_pos_in_items_ptr:
+            return MultiItemScoringParams()
+
+        # Build final params
+        device = forward_batch.input_ids.device
+        return MultiItemScoringParams(
+            prefix_len_ptr=torch.tensor(
+                prefix_len_ptr, dtype=torch.uint32, device=device
+            ),
+            token_pos_in_items_ptr=torch.cat(token_pos_in_items_ptr, dim=0),
+            token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF,
+            max_item_len_ptr=torch.stack(
+                [
+                    t.to(torch.int32).max().to(torch.uint16)
+                    for t in token_pos_in_items_ptr
+                ],
+                dim=0,
+            ),
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
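The position arithmetic in `_process_multi_item_scoring` uses `torch.cummax` over the delimiter mask to find, for every token, the index of the most recent delimiter. A standalone check of Case 1 from the docstring (delimiters at token indices 7, 9, 11, 13; all names below are local to this sketch):

    import torch

    positions = torch.arange(14)                 # positions 0..13, no cached prefix
    mask = torch.zeros(14, dtype=torch.int64)
    mask[[7, 9, 11, 13]] = 1                     # delimiter tokens

    first_delim = int(torch.nonzero(mask)[0])    # 7
    # Index of the latest delimiter seen so far, relative to first_delim.
    last_delim_idx = torch.cummax(mask[first_delim:], dim=0)[1]
    diff = positions[first_delim:] - last_delim_idx
    token_pos = diff - positions[first_delim]

    print(token_pos.tolist())                    # [0, 1, 0, 1, 0, 1, 0]

The printed value matches the expected token_pos_in_items_ptr for Case 1 in the docstring.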
@@ -280,13 +455,26 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
-            if self.is_multimodal:
+            # Disable ragged wrapper and ensure prefix handling for multimodal and multi-item scoring
+            if self.is_multimodal or self.multi_item_scoring_delimiter is not None:
+                # use_ragged = False: Multi-item scoring requires the paged wrapper because:
+                # 1. Ragged wrapper doesn't support the specialized multi-item parameters
+                #    (prefix_len_ptr, token_pos_in_items_ptr, etc.)
+                # 2. Paged wrapper provides better control over attention masking needed
+                #    for respecting item boundaries in multi-item sequences
+                # 3. Custom masking logic conflicts with ragged wrapper's assumptions
                 use_ragged = False
                 extend_no_prefix = False
             else:
                 use_ragged = not self.enable_deterministic
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
 
+            # Process multi-item scoring in attention backend instead of ForwardBatch
+            multi_item_params = MultiItemScoringParams()
+            if self.multi_item_scoring_delimiter is not None:
+                # Use new backend-specific implementation
+                multi_item_params = self._process_multi_item_scoring(forward_batch)
+
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
@@ -298,9 +486,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=None,
                 fixed_split_size=self.prefill_split_tile_size,
+                multi_item_params=multi_item_params,
             )
             self.forward_metadata = PrefillMetadata(
-                self.prefill_wrappers_paged, use_ragged, extend_no_prefix
+                self.prefill_wrappers_paged,
+                use_ragged,
+                extend_no_prefix,
+                multi_item_params,
             )
 
     def init_cuda_graph_state(
@@ -531,7 +723,20 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
+                # Disable sliding window attention for multi-item scoring:
+                # - Sliding window could cut across item boundaries, breaking semantic coherence
+                # - Multi-item sequences need full attention to properly handle delimiter tokens
+                # - Specialized multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr)
+                #   provide more precise attention control than simple sliding windows
+                # - Item-aware masking takes precedence over window-based masking
+                window_left=(
+                    layer.sliding_window_size
+                    if not (
+                        self.forward_metadata.multi_item_params
+                        and self.forward_metadata.multi_item_params.is_enabled()
+                    )
+                    else -1
+                ),
                 logits_soft_cap=logits_soft_cap,
                 # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
                 k_scale=layer.k_scale_float,
@@ -539,9 +744,13 @@ class FlashInferAttnBackend(AttentionBackend):
             )
         else:
             causal = True
-            if layer.attn_type == AttentionType.ENCODER_ONLY:
-                save_kv_cache = False
+            if (
+                layer.is_cross_attention
+                or layer.attn_type == AttentionType.ENCODER_ONLY
+            ):
                 causal = False
+            if save_kv_cache and layer.attn_type == AttentionType.ENCODER_ONLY:
+                save_kv_cache = False
 
             if self.forward_metadata.extend_no_prefix:
                 # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
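The rewritten branch above changes two things for this code path: cross-attention layers now also run non-causally, and save_kv_cache is cleared only for encoder-only layers (and only if it was set). A tiny paraphrase of the new control flow (the helper name is illustrative, not from the diff):

    from typing import Tuple

    def resolve_causal_and_save(
        is_cross_attention: bool, is_encoder_only: bool, save_kv_cache: bool
    ) -> Tuple[bool, bool]:
        causal = True
        if is_cross_attention or is_encoder_only:
            causal = False
        if save_kv_cache and is_encoder_only:
            save_kv_cache = False
        return causal, save_kv_cache

    # Cross-attention keeps writing the KV cache but is no longer treated as causal.
    assert resolve_causal_and_save(True, False, True) == (False, True)
    # Encoder-only layers stay non-causal and skip the KV-cache write, as before.
    assert resolve_causal_and_save(False, True, True) == (False, False)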
@@ -952,6 +1161,7 @@ class FlashInferIndicesUpdaterPrefill:
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         if use_ragged:
             # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -976,6 +1186,7 @@ class FlashInferIndicesUpdaterPrefill:
             use_ragged,
             spec_info,
             fixed_split_size=fixed_split_size,
+            multi_item_params=multi_item_params,
         )
 
     def update_sliding_window(
@@ -990,6 +1201,7 @@
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -1023,6 +1235,7 @@
                 use_ragged,
                 spec_info,
                 use_sliding_window_kv_pool=use_sliding_window_kv_pool,
+                multi_item_params=multi_item_params,
             )
 
     def update_cross_attention(
@@ -1037,6 +1250,7 @@
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -1063,6 +1277,7 @@
                 self.qo_indptr[wrapper_id],
                 use_ragged,
                 spec_info,
+                multi_item_params=multi_item_params,
             )
 
     def call_begin_forward(
@@ -1081,6 +1296,7 @@ class FlashInferIndicesUpdaterPrefill:
         spec_info: Optional[SpecInput],
         use_sliding_window_kv_pool: bool = False,
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -1136,6 +1352,22 @@
            )
 
         # cached part
+        # Conditionally set multi-item parameters
+        if multi_item_params is not None and multi_item_params.is_enabled():
+            # Multi-item scoring is active - use specialized parameters and disable generic custom_mask
+            use_custom_mask = None
+            prefix_len_ptr = multi_item_params.prefix_len_ptr
+            token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr
+            token_pos_in_items_len = multi_item_params.token_pos_in_items_len
+            max_item_len_ptr = multi_item_params.max_item_len_ptr
+        else:
+            # No multi-item scoring - use standard parameters
+            use_custom_mask = custom_mask
+            prefix_len_ptr = None
+            token_pos_in_items_ptr = None
+            token_pos_in_items_len = 0
+            max_item_len_ptr = None
+
         wrapper_paged.begin_forward(
             qo_indptr,
             kv_indptr,
@@ -1147,9 +1379,13 @@
             1,
             q_data_type=self.q_data_type,
             kv_data_type=self.data_type,
-            custom_mask=custom_mask,
+            custom_mask=use_custom_mask,
             non_blocking=True,
             fixed_split_size=fixed_split_size,
+            prefix_len_ptr=prefix_len_ptr,
+            token_pos_in_items_ptr=token_pos_in_items_ptr,
+            token_pos_in_items_len=token_pos_in_items_len,
+            max_item_len_ptr=max_item_len_ptr,
         )
 
 
@@ -1185,7 +1421,7 @@ class FlashInferMultiStepDraftBackend:
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
         self.attn_backends: List[FlashInferAttnBackend] = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferAttnBackend(
                     model_runner,
@@ -1273,7 +1509,7 @@ class FlashInferMultiStepDraftBackend:
             device="cuda",
         )
 
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )
sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -9,27 +9,20 @@ and uses BatchMLAPaged wrapper for decoding.
 More details can be found in https://docs.flashinfer.ai/api/mla.html
 """
 
-import os
 from dataclasses import dataclass
 from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
 
-if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import logging
-
-    torch._logging.set_logs(dynamo=logging.ERROR)
-    torch._dynamo.config.suppress_errors = True
-
-from sglang.global_config import global_config
+from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     is_flashinfer_available,
@@ -38,10 +31,19 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.attention.flashinfer_mla_backend import (
+        FlashInferMlaAttnBackend,
+    )
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.speculative.spec_info import SpecInput
 
+if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
+    import logging
+
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
+
 if is_flashinfer_available():
     from flashinfer import (
         BatchMLAPagedAttentionWrapper,
@@ -66,7 +68,7 @@
 
 class FlashInferMhaChunkKVRunner:
     def __init__(
-        self, model_runner: ModelRunner, attn_backend: "FlashInferMlaAttnBackend"
+        self, model_runner: ModelRunner, attn_backend: FlashInferMlaAttnBackend
     ):
         # Parse Constants
         self.num_local_heads = (
@@ -80,6 +82,7 @@ class FlashInferMhaChunkKVRunner:
 
         # Buffers and wrappers
         self.qo_indptr = attn_backend.qo_indptr
+        self.kv_indptr = attn_backend.kv_indptr
         self.workspace_buffer = attn_backend.workspace_buffer
         self.fmha_backend = attn_backend.fmha_backend
 
@@ -130,9 +133,14 @@ class FlashInferMhaChunkKVRunner:
            )
         # ragged prefill
         if not disable_flashinfer_ragged:
+            kv_indptr = (
+                qo_indptr
+                if not forward_batch.mha_one_shot
+                else self.kv_indptr[: bs + 1]
+            )
             self.ragged_wrapper.begin_forward(
                 qo_indptr=qo_indptr,
-                kv_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
                 num_qo_heads=self.num_local_heads,
                 num_kv_heads=self.num_local_heads,
                 head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
@@ -154,7 +162,7 @@ class FlashInferMhaChunkKVRunner:
             chunk_idx = forward_batch.prefix_chunk_idx
             assert chunk_idx >= 0
             wrapper = self.chunk_ragged_wrappers[chunk_idx]
-            o1, s1 = wrapper.forward_return_lse(
+            o = wrapper.forward_return_lse(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                 v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
@@ -163,7 +171,12 @@
                 logits_soft_cap=logits_soft_cap,
             )
         else:
-            o1, s1 = self.ragged_wrapper.forward_return_lse(
+            forward = (
+                self.ragged_wrapper.forward_return_lse
+                if forward_batch.mha_return_lse
+                else self.ragged_wrapper.forward
+            )
+            o = forward(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                 v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
@@ -171,8 +184,7 @@
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
             )
-
-        return o1, s1
+        return o
 
 
 class FlashInferMLAAttnBackend(AttentionBackend):
@@ -193,9 +205,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.skip_prefill = skip_prefill
         self.enable_chunk_kv = (
             not skip_prefill
-            and global_server_args_dict["disaggregation_mode"] != "decode"
-            and not global_server_args_dict["disable_chunked_prefix_cache"]
-            and not global_server_args_dict["flashinfer_mla_disable_ragged"]
+            and get_global_server_args().disaggregation_mode != "decode"
+            and not get_global_server_args().disable_chunked_prefix_cache
+            and not get_global_server_args().flashinfer_mla_disable_ragged
         )
         self.page_size = model_runner.page_size
 
@@ -204,7 +216,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
+                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
@@ -306,7 +318,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             prefix_lens = forward_batch.extend_prefix_lens
             extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
             use_ragged = (
-                not global_server_args_dict["flashinfer_mla_disable_ragged"]
+                not get_global_server_args().flashinfer_mla_disable_ragged
                 and extend_no_prefix
             )
 
@@ -510,15 +522,13 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
-        if (
-            forward_batch.attn_attend_prefix_cache is not None
-            and forward_batch.mha_return_lse
+        if forward_batch.attn_attend_prefix_cache is not None and any(
+            forward_batch.extend_prefix_lens_cpu
         ):  # MHA Chunk
             assert self.enable_chunk_kv
             assert q_rope is None
             assert k_rope is None
-            o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
-            return o1, s1
+            return self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
 
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
@@ -916,7 +926,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )
 
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferMLAAttnBackend(
                     model_runner,
@@ -998,7 +1008,7 @@
             device="cuda",
         )
 
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
@@ -1060,7 +1070,7 @@ def fast_mla_decode_plan(
 
     try:
         # Standard version with just the required arguments (no use_profiler)
-        self._cached_module.plan.default(
+        self._cached_module.plan(
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._pin_memory_int_workspace_buffer,