sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py

@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+import logging
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -15,20 +16,13 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
 import torch
 
-if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import logging
-
-    torch._logging.set_logs(dynamo=logging.ERROR)
-    torch._dynamo.config.suppress_errors = True
-
-from sglang.global_config import global_config
+from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     get_int_env_var,
@@ -41,6 +35,12 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+logger = logging.getLogger(__name__)
+
+if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
+
 
 if is_flashinfer_available():
     from flashinfer import (
@@ -50,7 +50,6 @@ if is_flashinfer_available():
         fast_decode_plan,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import _get_range_buf, get_seq_lens
 
 
 class WrapperDispatch(Enum):
@@ -58,6 +57,36 @@ class WrapperDispatch(Enum):
     CROSS_ATTENTION = auto()
 
 
+@dataclass
+class MultiItemScoringParams:
+    """Parameters for multi-item scoring in attention computation.
+
+    Used when processing sequences with multiple items separated by delimiters,
+    where each item needs specific attention patterns that respect item boundaries.
+
+    Attributes:
+        prefix_len_ptr: A uint32 1D tensor indicating the prefix length of each prompt.
+            The tensor size is equal to the batch size.
+        token_pos_in_items_ptr: A uint16 1D tensor indicating the token position of each item
+            starting from 0 (delimiter) for each item. For batch size > 1,
+            sequences are concatenated with zero padding to ensure same length.
+        token_pos_in_items_len: Zero padding length for token_pos_in_items_ptr to handle
+            batch_size > 1 case. Defines the padded length for each sequence.
+        max_item_len_ptr: A uint16 tensor containing the max token length of all items
+            for each prompt in the batch.
+
+    """
+
+    prefix_len_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_len: int = 0
+    max_item_len_ptr: Optional[torch.Tensor] = None
+
+    def is_enabled(self) -> bool:
+        """Check if multi-item scoring is enabled."""
+        return self.prefix_len_ptr is not None
+
+
 @dataclass
 class DecodeMetadata:
     decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
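A minimal sketch (not part of the diff) of how the new dataclass encodes Case 1 from the _process_multi_item_scoring docstring further down: a 7-token question followed by three single-token options, each bracketed by delimiters.

    import torch

    from sglang.srt.layers.attention.flashinfer_backend import MultiItemScoringParams

    params = MultiItemScoringParams(
        prefix_len_ptr=torch.tensor([7], dtype=torch.uint32),  # query length before the first delimiter
        token_pos_in_items_ptr=torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint16),
        token_pos_in_items_len=7,                              # unpadded length (batch size 1)
        max_item_len_ptr=torch.tensor([1], dtype=torch.uint16),  # every option is a single token
    )
    assert params.is_enabled()                        # prefix_len_ptr is set
    assert not MultiItemScoringParams().is_enabled()  # the default disables the path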
@@ -68,6 +97,7 @@ class PrefillMetadata:
     prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
     use_ragged: bool
     extend_no_prefix: bool
+    multi_item_params: Optional[MultiItemScoringParams] = None
 
 
 # Reuse this workspace buffer across all flashinfer wrappers
@@ -87,9 +117,15 @@ class FlashInferAttnBackend(AttentionBackend):
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
         kv_last_page_len_buf: Optional[torch.Tensor] = None,
+        init_new_workspace: bool = False,
     ):
         super().__init__()
 
+        # Store multi-item scoring delimiter for efficient access
+        self.multi_item_scoring_delimiter = (
+            model_runner.server_args.multi_item_scoring_delimiter
+        )
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -124,7 +160,7 @@ class FlashInferAttnBackend(AttentionBackend):
             or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
             or "MiMoForCausalLM" in model_runner.model_config.hf_config.architectures
         ):
-            global_config.flashinfer_workspace_size = 512 * 1024 * 1024
+            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(512 * 1024 * 1024)
 
         # When deterministic inference is enabled, tensor cores should be used for decode
         # Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph
@@ -144,19 +180,26 @@ class FlashInferAttnBackend(AttentionBackend):
                 "SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
             )
             self.disable_cuda_graph_kv_split = True
-            global_config.flashinfer_workspace_size = 2048 * 1024 * 1024
+            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(2048 * 1024 * 1024)
 
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
-            global_workspace_size = global_config.flashinfer_workspace_size
+            global_workspace_size = envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get()
             global_workspace_buffer = torch.empty(
                 global_workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
-        self.workspace_buffer = global_workspace_buffer
+        if init_new_workspace:
+            self.workspace_buffer = torch.empty(
+                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        else:
+            self.workspace_buffer = global_workspace_buffer
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [
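Both workspace-size writes in this file follow the same migration: the mutable attribute on sglang.global_config.global_config becomes a typed accessor on sglang.srt.environ.envs, read with .get() and overridden in-process with .set(). A minimal before/after sketch of the two call patterns used in this diff:

    # 0.5.3rc2: plain attribute on a module-level config object
    from sglang.global_config import global_config

    global_config.flashinfer_workspace_size = 512 * 1024 * 1024
    workspace_size = global_config.flashinfer_workspace_size

    # 0.5.4: environment-backed accessor with explicit get()/set()
    from sglang.srt.environ import envs

    envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(512 * 1024 * 1024)
    workspace_size = envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get()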
@@ -229,10 +272,133 @@ class FlashInferAttnBackend(AttentionBackend):
 
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
+
         self.decode_cuda_graph_metadata = {}
         self.prefill_cuda_graph_metadata = {}  # For verify
         self.draft_extend_cuda_graph_metadata = {}  # For draft extend
 
+    def _process_multi_item_scoring(
+        self, forward_batch: ForwardBatch
+    ) -> MultiItemScoringParams:
+        """Process multi-item scoring tensors for FlashInfer attention.
+
+        This method handles sequences containing multiple "items" separated by delimiter tokens,
+        where each item needs specific attention patterns that respect item boundaries.
+
+        The method produces four key tensors for FlashInfer:
+        - prefix_len_ptr: uint32 tensor with prefix length for each prompt in batch
+        - token_pos_in_items_ptr: uint16 tensor with token positions starting from 0 at delimiters
+        - token_pos_in_items_len: padding length for batch processing
+        - max_item_len_ptr: uint16 tensor with max item length for each prompt
+
+        Args:
+            forward_batch: The forward batch containing input sequences and delimiter info
+
+        Returns:
+            MultiItemScoringParams: The processed multi-item scoring parameters
+
+        Examples:
+            Following FlashInfer definition: for 3 items of length 3, 2, 4 respectively:
+            token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
+
+            Case 1: Single sequence
+            Text: "What is the capital of France? <delim> London <delim> Paris <delim> Berlin <delim>"
+            Tokens: [What, is, the, capital, of, France, ?, <delim>, London, <delim>, Paris, <delim>, Berlin, <delim>]
+            Indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+            - prefix_len_ptr: [7] (query length before first delimiter)
+            - token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0] (delim=0, London=1, delim=0, Paris=1, delim=0, Berlin=1, delim=0)
+            - token_pos_in_items_len: 7 (actual length)
+            - max_item_len_ptr: [1] (max item length is 1 token - all options are single tokens)
+
+            Case 2: Batch processing (batch_size=2)
+            Sequence 1: 2 items of length 2, 1 → [0, 1, 2, 0, 1, 0] (6 elements)
+            Sequence 2: 3 items of length 1, 3, 2 → [0, 1, 0, 1, 2, 3, 0, 1, 2, 0] (10 elements)
+            After padding both to length 10:
+            - token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0]
+            - token_pos_in_items_len: 10 (padded length for batch processing)
+            - max_item_len_ptr: [2, 3] (max lengths per sequence)
+        """
+
+        delimiter = self.multi_item_scoring_delimiter
+        if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE:
+            return MultiItemScoringParams()
+
+        delimiter_mask = forward_batch.input_ids == delimiter
+        prefix_cache_lens = getattr(forward_batch, "extend_prefix_lens", None)
+        extend_seq_lens = getattr(forward_batch, "extend_seq_lens", None)
+        prefix_len_ptr, token_pos_in_items_ptr = [], []
+        token_pos_in_items_len = 0
+
+        # If no extend_seq_lens, treat whole batch as one sequence
+        if extend_seq_lens is None or len(extend_seq_lens) <= 1:
+            extend_seq_lens = [forward_batch.input_ids.size(0)]
+
+        seq_start = 0
+        for i, seq_len in enumerate(extend_seq_lens):
+            seq_end = seq_start + seq_len
+            mask = delimiter_mask[seq_start:seq_end]
+            pos = forward_batch.positions[seq_start:seq_end]
+            delimiter_indices = torch.nonzero(mask, as_tuple=True)[0]
+
+            if len(delimiter_indices) > 0:
+                first_delim = delimiter_indices[0]
+                # Prefix length: store as scalar
+                prefix_len = first_delim + (
+                    prefix_cache_lens[i] if prefix_cache_lens is not None else 0
+                )
+                prefix_len_ptr.append(
+                    prefix_len.item() if torch.is_tensor(prefix_len) else prefix_len
+                )
+
+                # Compute relative positions within items after delimiters
+                diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1]
+                token_pos = (diff - pos[first_delim]).to(torch.uint16)
+                token_pos_in_items_ptr.append(token_pos)
+
+                # Update forward_batch positions in-place
+                pos[first_delim:] = diff - 1
+                forward_batch.positions[seq_start:seq_end] = pos
+
+            seq_start = seq_end
+
+        # Pad token_pos_in_items_ptr for batch processing
+        if token_pos_in_items_ptr:
+            token_pos_in_items_len = max(t.numel() for t in token_pos_in_items_ptr)
+            device = forward_batch.input_ids.device
+            token_pos_in_items_ptr = [
+                torch.cat(
+                    [
+                        t,
+                        torch.zeros(
+                            token_pos_in_items_len - t.numel(),
+                            dtype=torch.uint16,
+                            device=device,
+                        ),
+                    ]
+                )
+                for t in token_pos_in_items_ptr
+            ]
+
+        if not prefix_len_ptr or not token_pos_in_items_ptr:
+            return MultiItemScoringParams()
+
+        # Build final params
+        device = forward_batch.input_ids.device
+        return MultiItemScoringParams(
+            prefix_len_ptr=torch.tensor(
+                prefix_len_ptr, dtype=torch.uint32, device=device
+            ),
+            token_pos_in_items_ptr=torch.cat(token_pos_in_items_ptr, dim=0),
+            token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF,
+            max_item_len_ptr=torch.stack(
+                [
+                    t.to(torch.int32).max().to(torch.uint16)
+                    for t in token_pos_in_items_ptr
+                ],
+                dim=0,
+            ),
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
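The position rewrite above leans on a standard torch.cummax idiom: over the suffix that starts at the first delimiter, torch.cummax(mask, 0)[1] forward-fills the index of the most recent delimiter, so subtracting it from each token's running position yields that token's offset within its item. A self-contained check with toy values mirroring Case 1 (not part of the diff):

    import torch

    # Slice starting at the first delimiter:
    # [<delim>, London, <delim>, Paris, <delim>, Berlin, <delim>]
    mask = torch.tensor([1, 0, 1, 0, 1, 0, 1])    # 1 marks a delimiter token
    rel = torch.arange(mask.numel())              # position relative to the slice start
    last_delim = torch.cummax(mask, 0)[1]         # index of the most recent delimiter
    token_pos = rel - last_delim                  # offset of each token within its item
    assert token_pos.tolist() == [0, 1, 0, 1, 0, 1, 0]

The same arithmetic runs on absolute positions in the method; subtracting pos[first_delim] at the end reduces it to the relative form shown here.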
@@ -280,13 +446,26 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
-        if self.is_multimodal:
+        # Disable ragged wrapper and ensure prefix handling for multimodal and multi-item scoring
+        if self.is_multimodal or self.multi_item_scoring_delimiter is not None:
+            # use_ragged = False: Multi-item scoring requires the paged wrapper because:
+            # 1. Ragged wrapper doesn't support the specialized multi-item parameters
+            #    (prefix_len_ptr, token_pos_in_items_ptr, etc.)
+            # 2. Paged wrapper provides better control over attention masking needed
+            #    for respecting item boundaries in multi-item sequences
+            # 3. Custom masking logic conflicts with ragged wrapper's assumptions
             use_ragged = False
             extend_no_prefix = False
         else:
             use_ragged = not self.enable_deterministic
             extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
 
+        # Process multi-item scoring in attention backend instead of ForwardBatch
+        multi_item_params = MultiItemScoringParams()
+        if self.multi_item_scoring_delimiter is not None:
+            # Use new backend-specific implementation
+            multi_item_params = self._process_multi_item_scoring(forward_batch)
+
         self.indices_updater_prefill.update(
             forward_batch.req_pool_indices,
             forward_batch.seq_lens,
@@ -298,9 +477,13 @@ class FlashInferAttnBackend(AttentionBackend):
             encoder_lens=forward_batch.encoder_lens,
             spec_info=None,
             fixed_split_size=self.prefill_split_tile_size,
+            multi_item_params=multi_item_params,
         )
         self.forward_metadata = PrefillMetadata(
-            self.prefill_wrappers_paged, use_ragged, extend_no_prefix
+            self.prefill_wrappers_paged,
+            use_ragged,
+            extend_no_prefix,
+            multi_item_params,
         )
 
     def init_cuda_graph_state(
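A hypothetical standalone restatement (not sglang API) of the prefill-path selection above, showing that multimodal inputs and an active multi-item delimiter both force the paged wrapper:

    def choose_prefill_path(is_multimodal, delimiter, enable_deterministic, prefix_lens_cpu):
        """Returns (use_ragged, extend_no_prefix) as decided in init_forward_metadata."""
        if is_multimodal or delimiter is not None:
            return False, False
        return not enable_deterministic, not any(prefix_lens_cpu)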
@@ -531,7 +714,20 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
+                # Disable sliding window attention for multi-item scoring:
+                # - Sliding window could cut across item boundaries, breaking semantic coherence
+                # - Multi-item sequences need full attention to properly handle delimiter tokens
+                # - Specialized multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr)
+                #   provide more precise attention control than simple sliding windows
+                # - Item-aware masking takes precedence over window-based masking
+                window_left=(
+                    layer.sliding_window_size
+                    if not (
+                        self.forward_metadata.multi_item_params
+                        and self.forward_metadata.multi_item_params.is_enabled()
+                    )
+                    else -1
+                ),
                 logits_soft_cap=logits_soft_cap,
                 # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
                 k_scale=layer.k_scale_float,
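Distilled from the window_left argument above (a sketch, not the diff's exact code): FlashInfer treats window_left=-1 as "no sliding window", so item-aware masking replaces window-based masking whenever multi-item scoring is active.

    def effective_window_left(sliding_window_size, multi_item_params):
        if multi_item_params is not None and multi_item_params.is_enabled():
            return -1  # full attention; item boundaries come from the multi-item tensors
        return sliding_window_size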
@@ -539,9 +735,13 @@ class FlashInferAttnBackend(AttentionBackend):
             )
         else:
             causal = True
-            if layer.attn_type == AttentionType.ENCODER_ONLY:
-                save_kv_cache = False
+            if (
+                layer.is_cross_attention
+                or layer.attn_type == AttentionType.ENCODER_ONLY
+            ):
                 causal = False
+            if save_kv_cache and layer.attn_type == AttentionType.ENCODER_ONLY:
+                save_kv_cache = False
 
             if self.forward_metadata.extend_no_prefix:
                 # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
@@ -952,6 +1152,7 @@ class FlashInferIndicesUpdaterPrefill:
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
    ):
        if use_ragged:
            # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -976,6 +1177,7 @@ class FlashInferIndicesUpdaterPrefill:
             use_ragged,
             spec_info,
             fixed_split_size=fixed_split_size,
+            multi_item_params=multi_item_params,
         )
 
     def update_sliding_window(
@@ -990,6 +1192,7 @@ class FlashInferIndicesUpdaterPrefill:
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -1023,6 +1226,7 @@ class FlashInferIndicesUpdaterPrefill:
                 use_ragged,
                 spec_info,
                 use_sliding_window_kv_pool=use_sliding_window_kv_pool,
+                multi_item_params=multi_item_params,
             )
 
     def update_cross_attention(
@@ -1037,6 +1241,7 @@ class FlashInferIndicesUpdaterPrefill:
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -1063,6 +1268,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.qo_indptr[wrapper_id],
                 use_ragged,
                 spec_info,
+                multi_item_params=multi_item_params,
             )
 
     def call_begin_forward(
@@ -1081,6 +1287,7 @@ class FlashInferIndicesUpdaterPrefill:
         spec_info: Optional[SpecInput],
         use_sliding_window_kv_pool: bool = False,
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -1136,6 +1343,22 @@ class FlashInferIndicesUpdaterPrefill:
             )
 
         # cached part
+        # Conditionally set multi-item parameters
+        if multi_item_params is not None and multi_item_params.is_enabled():
+            # Multi-item scoring is active - use specialized parameters and disable generic custom_mask
+            use_custom_mask = None
+            prefix_len_ptr = multi_item_params.prefix_len_ptr
+            token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr
+            token_pos_in_items_len = multi_item_params.token_pos_in_items_len
+            max_item_len_ptr = multi_item_params.max_item_len_ptr
+        else:
+            # No multi-item scoring - use standard parameters
+            use_custom_mask = custom_mask
+            prefix_len_ptr = None
+            token_pos_in_items_ptr = None
+            token_pos_in_items_len = 0
+            max_item_len_ptr = None
+
         wrapper_paged.begin_forward(
             qo_indptr,
             kv_indptr,
@@ -1147,9 +1370,13 @@ class FlashInferIndicesUpdaterPrefill:
             1,
             q_data_type=self.q_data_type,
             kv_data_type=self.data_type,
-            custom_mask=custom_mask,
+            custom_mask=use_custom_mask,
             non_blocking=True,
             fixed_split_size=fixed_split_size,
+            prefix_len_ptr=prefix_len_ptr,
+            token_pos_in_items_ptr=token_pos_in_items_ptr,
+            token_pos_in_items_len=token_pos_in_items_len,
+            max_item_len_ptr=max_item_len_ptr,
         )
 
 
@@ -1185,7 +1412,7 @@ class FlashInferMultiStepDraftBackend:
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
         self.attn_backends: List[FlashInferAttnBackend] = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferAttnBackend(
                     model_runner,
@@ -1273,7 +1500,7 @@ class FlashInferMultiStepDraftBackend:
             device="cuda",
         )
 
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -9,27 +9,20 @@ and uses BatchMLAPaged wrapper for decoding.
 More details can be found in https://docs.flashinfer.ai/api/mla.html
 """
 
-import os
 from dataclasses import dataclass
 from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
 
-if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import logging
-
-    torch._logging.set_logs(dynamo=logging.ERROR)
-    torch._dynamo.config.suppress_errors = True
-
-from sglang.global_config import global_config
+from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     is_flashinfer_available,
@@ -38,10 +31,19 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.attention.flashinfer_mla_backend import (
+        FlashInferMlaAttnBackend,
+    )
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.speculative.spec_info import SpecInput
 
+if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
+    import logging
+
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
+
 if is_flashinfer_available():
     from flashinfer import (
         BatchMLAPagedAttentionWrapper,
@@ -66,7 +68,7 @@ global_workspace_buffer = None
 
 class FlashInferMhaChunkKVRunner:
     def __init__(
-        self, model_runner: ModelRunner, attn_backend: "FlashInferMlaAttnBackend"
+        self, model_runner: ModelRunner, attn_backend: FlashInferMlaAttnBackend
     ):
         # Parse Constants
         self.num_local_heads = (
@@ -193,9 +195,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.skip_prefill = skip_prefill
         self.enable_chunk_kv = (
             not skip_prefill
-            and global_server_args_dict["disaggregation_mode"] != "decode"
-            and not global_server_args_dict["disable_chunked_prefix_cache"]
-            and not global_server_args_dict["flashinfer_mla_disable_ragged"]
+            and get_global_server_args().disaggregation_mode != "decode"
+            and not get_global_server_args().disable_chunked_prefix_cache
+            and not get_global_server_args().flashinfer_mla_disable_ragged
         )
         self.page_size = model_runner.page_size
 
@@ -204,7 +206,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
+                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
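The global_server_args_dict lookups removed in this file follow the same style of migration as the envs change: string-keyed dict access becomes attribute access on the ServerArgs object returned by get_global_server_args(). A before/after sketch of the two styles, as they appear in this diff:

    # 0.5.3rc2: string-keyed dict exported by schedule_batch
    from sglang.srt.managers.schedule_batch import global_server_args_dict

    disable_ragged = global_server_args_dict["flashinfer_mla_disable_ragged"]

    # 0.5.4: typed accessor returning the ServerArgs instance
    from sglang.srt.server_args import get_global_server_args

    disable_ragged = get_global_server_args().flashinfer_mla_disable_ragged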
@@ -306,7 +308,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             prefix_lens = forward_batch.extend_prefix_lens
             extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
             use_ragged = (
-                not global_server_args_dict["flashinfer_mla_disable_ragged"]
+                not get_global_server_args().flashinfer_mla_disable_ragged
                 and extend_no_prefix
             )
 
@@ -916,7 +918,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )
 
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferMLAAttnBackend(
                     model_runner,
@@ -998,7 +1000,7 @@ class FlashInferMLAMultiStepDraftBackend:
             device="cuda",
         )
 
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
@@ -1060,7 +1062,7 @@ def fast_mla_decode_plan(
 
     try:
         # Standard version with just the required arguments (no use_profiler)
-        self._cached_module.plan.default(
+        self._cached_module.plan(
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._pin_memory_int_workspace_buffer,
sglang/srt/layers/attention/flashmla_backend.py

@@ -478,7 +478,7 @@ class FlashMLAMultiStepDraftBackend:
         )
 
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashMLABackend(
                     model_runner,
@@ -506,7 +506,7 @@ class FlashMLAMultiStepDraftBackend:
         self.common_template(forward_batch, call_fn)
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, block_kv_indices=None
             )
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Optional
 
 import torch