sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482) hide show
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import TYPE_CHECKING, Optional, Union
4
+ from typing import TYPE_CHECKING, List, Optional
5
5
 
6
6
  import torch
7
7
  import triton
@@ -12,12 +12,18 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
12
12
  from sglang.srt.layers.dp_attention import get_attention_tp_size
13
13
  from sglang.srt.layers.radix_attention import AttentionType
14
14
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
15
- from sglang.srt.utils import get_bool_env_var, get_device_core_count, next_power_of_2
15
+ from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
16
+ from sglang.srt.utils import (
17
+ get_bool_env_var,
18
+ get_device_core_count,
19
+ get_int_env_var,
20
+ next_power_of_2,
21
+ )
16
22
 
17
23
  if TYPE_CHECKING:
18
24
  from sglang.srt.layers.radix_attention import RadixAttention
19
25
  from sglang.srt.model_executor.model_runner import ModelRunner
20
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
26
+ from sglang.srt.speculative.spec_info import SpecInput
21
27
 
22
28
 
23
29
  def logit_capping_mod(logit_capping_method, logit_cap):
@@ -58,13 +64,19 @@ class TritonAttnBackend(AttentionBackend):
58
64
  decode_attention_fwd,
59
65
  )
60
66
  from sglang.srt.layers.attention.triton_ops.extend_attention import (
67
+ build_unified_kv_indices,
61
68
  extend_attention_fwd,
69
+ extend_attention_fwd_unified,
62
70
  )
63
71
 
64
72
  super().__init__()
65
73
 
66
74
  self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
67
75
  self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
76
+ self.extend_attention_fwd_unified = torch.compiler.disable(
77
+ extend_attention_fwd_unified
78
+ )
79
+ self.build_unified_kv_indices = torch.compiler.disable(build_unified_kv_indices)
68
80
 
69
81
  # Parse args
70
82
  self.skip_prefill = skip_prefill
@@ -80,7 +92,7 @@ class TritonAttnBackend(AttentionBackend):
80
92
  self.num_kv_head = model_runner.model_config.get_num_kv_heads(
81
93
  get_attention_tp_size()
82
94
  )
83
- if model_runner.is_hybrid_gdn:
95
+ if model_runner.hybrid_gdn_config is not None:
84
96
  # For hybrid linear models, layer_id = 0 may not be full attention
85
97
  self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
86
98
  else:
@@ -94,7 +106,25 @@ class TritonAttnBackend(AttentionBackend):
94
106
  "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
95
107
  )
96
108
  self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
97
- self.split_tile_size = model_runner.server_args.triton_attention_split_tile_size
109
+
110
+ # Decide whether enable deterministic inference with batch-invariant operations
111
+ self.enable_deterministic = (
112
+ model_runner.server_args.enable_deterministic_inference
113
+ )
114
+
115
+ # Configure deterministic inference settings
116
+ if self.enable_deterministic:
117
+ # Use fixed split tile size for batch invariance
118
+ self.split_tile_size = get_int_env_var(
119
+ "SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256
120
+ )
121
+ # Set static_kv_splits to False to use deterministic logic instead
122
+ self.static_kv_splits = False
123
+ else:
124
+ self.split_tile_size = (
125
+ model_runner.server_args.triton_attention_split_tile_size
126
+ )
127
+
98
128
  if self.split_tile_size is not None:
99
129
  self.max_kv_splits = (
100
130
  self.max_context_len + self.split_tile_size - 1
@@ -139,6 +169,8 @@ class TritonAttnBackend(AttentionBackend):
139
169
  # Initialize forward metadata
140
170
  self.forward_metadata: ForwardMetadata = None
141
171
 
172
+ self.cuda_graph_custom_mask = None
173
+
142
174
  def get_num_kv_splits(
143
175
  self,
144
176
  num_kv_splits: torch.Tensor,
@@ -154,13 +186,23 @@ class TritonAttnBackend(AttentionBackend):
154
186
  num_group * num_seq == num_token
155
187
  ), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"
156
188
 
157
- if self.static_kv_splits or self.device_core_count <= 0:
189
+ # Legacy dynamic splitting logic (non-deterministic)
190
+ if (
191
+ self.static_kv_splits or self.device_core_count <= 0
192
+ ) and not self.enable_deterministic:
158
193
  num_kv_splits.fill_(self.max_kv_splits)
159
194
  return
160
195
 
161
- if self.split_tile_size is not None:
196
+ # deterministic
197
+ if self.split_tile_size is not None and self.enable_deterministic:
198
+ # expand seq_lens to match num_token
199
+ if num_group > 1:
200
+ expanded_seq_lens = seq_lens.repeat_interleave(num_group)
201
+ else:
202
+ expanded_seq_lens = seq_lens
203
+
162
204
  num_kv_splits[:] = (
163
- seq_lens + self.split_tile_size - 1
205
+ expanded_seq_lens + self.split_tile_size - 1
164
206
  ) // self.split_tile_size
165
207
  return
166
208
 
@@ -329,7 +371,7 @@ class TritonAttnBackend(AttentionBackend):
329
371
  )
330
372
  kv_indptr = kv_indptr[: bs + 1]
331
373
  kv_indices = torch.empty(
332
- forward_batch.extend_prefix_lens.sum().item(),
374
+ sum(forward_batch.extend_prefix_lens_cpu),
333
375
  dtype=torch.int64,
334
376
  device=self.device,
335
377
  )
@@ -388,6 +430,7 @@ class TritonAttnBackend(AttentionBackend):
388
430
  max_bs: int,
389
431
  max_num_tokens: int,
390
432
  kv_indices_buf: Optional[torch.Tensor] = None,
433
+ cuda_graph_num_kv_splits_buf: Optional[torch.Tensor] = None,
391
434
  ):
392
435
  self.cuda_graph_attn_logits = torch.zeros(
393
436
  (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
@@ -399,9 +442,17 @@ class TritonAttnBackend(AttentionBackend):
399
442
  dtype=torch.float32,
400
443
  device=self.device,
401
444
  )
402
- self.cuda_graph_num_kv_splits = torch.full(
403
- (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
404
- )
445
+
446
+ if cuda_graph_num_kv_splits_buf is None:
447
+ self.cuda_graph_num_kv_splits = torch.full(
448
+ (max_num_tokens,),
449
+ self.max_kv_splits,
450
+ dtype=torch.int32,
451
+ device=self.device,
452
+ )
453
+ else:
454
+ self.cuda_graph_num_kv_splits = cuda_graph_num_kv_splits_buf
455
+
405
456
  if kv_indices_buf is None:
406
457
  self.cuda_graph_kv_indices = torch.zeros(
407
458
  (max_num_tokens * self.max_context_len),
@@ -449,7 +500,7 @@ class TritonAttnBackend(AttentionBackend):
449
500
  seq_lens: torch.Tensor,
450
501
  encoder_lens: Optional[torch.Tensor],
451
502
  forward_mode: ForwardMode,
452
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
503
+ spec_info: Optional[SpecInput],
453
504
  ):
454
505
  assert encoder_lens is None, "Not supported"
455
506
  window_kv_indptr = self.window_kv_indptr
@@ -605,7 +656,7 @@ class TritonAttnBackend(AttentionBackend):
605
656
  seq_lens_sum: int,
606
657
  encoder_lens: Optional[torch.Tensor],
607
658
  forward_mode: ForwardMode,
608
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
659
+ spec_info: Optional[SpecInput],
609
660
  seq_lens_cpu: Optional[torch.Tensor],
610
661
  ):
611
662
  # NOTE: encoder_lens expected to be zeros or None
@@ -648,9 +699,7 @@ class TritonAttnBackend(AttentionBackend):
648
699
  )
649
700
 
650
701
  else:
651
- kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
652
- kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
653
- num_token = spec_info.kv_indptr.shape[0] - 1
702
+ assert False, "Multi-step cuda graph init is not done here."
654
703
  self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
655
704
 
656
705
  elif forward_mode.is_target_verify():
@@ -722,6 +771,19 @@ class TritonAttnBackend(AttentionBackend):
722
771
  def get_cuda_graph_seq_len_fill_value(self):
723
772
  return 1
724
773
 
774
+ def get_verify_buffers_to_fill_after_draft(self):
775
+ """
776
+ Return buffers for verify attention kernels that needs to be filled after draft.
777
+
778
+ Typically, these are tree mask and position buffers.
779
+ """
780
+ return [self.cuda_graph_custom_mask, None]
781
+
782
+ def update_verify_buffers_to_fill_after_draft(
783
+ self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
784
+ ):
785
+ pass
786
+
725
787
  def forward_extend(
726
788
  self,
727
789
  q: torch.Tensor,
@@ -738,6 +800,7 @@ class TritonAttnBackend(AttentionBackend):
738
800
  else:
739
801
  o = torch.empty_like(q)
740
802
 
803
+ # Save KV cache first (must do this before unified kernel)
741
804
  if save_kv_cache:
742
805
  forward_batch.token_to_kv_pool.set_kv_buffer(
743
806
  layer, forward_batch.out_cache_loc, k, v
@@ -746,9 +809,16 @@ class TritonAttnBackend(AttentionBackend):
746
809
  logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
747
810
 
748
811
  causal = True
749
- if layer.attn_type == AttentionType.ENCODER_ONLY:
812
+ if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
750
813
  causal = False
751
814
 
815
+ # Deterministic mode: use unified 1-stage kernel
816
+ if self.enable_deterministic:
817
+ return self._forward_extend_unified(
818
+ q, o, layer, forward_batch, causal, logits_soft_cap, sinks
819
+ )
820
+
821
+ # Normal mode: use original 2-stage kernel
752
822
  if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
753
823
  sliding_window_size = (
754
824
  layer.sliding_window_size
@@ -785,6 +855,127 @@ class TritonAttnBackend(AttentionBackend):
785
855
  )
786
856
  return o
787
857
 
858
+ def _forward_extend_unified(
859
+ self,
860
+ q: torch.Tensor,
861
+ o: torch.Tensor,
862
+ layer: RadixAttention,
863
+ forward_batch: ForwardBatch,
864
+ causal: bool,
865
+ logits_soft_cap: float,
866
+ sinks: Optional[torch.Tensor],
867
+ ):
868
+ """
869
+ Unified 1-stage extend attention for deterministic inference.
870
+ Both prefix and extend KV are accessed through unified kv_indices.
871
+ """
872
+ bs = forward_batch.batch_size
873
+
874
+ # Determine sliding window settings
875
+ if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
876
+ sliding_window_size = layer.sliding_window_size
877
+ # Note: for unified kernel, we use full kv_indptr (not window)
878
+ prefix_kv_indptr = self.forward_metadata.window_kv_indptr
879
+ prefix_kv_indices = self.forward_metadata.window_kv_indices
880
+ # Compute window start positions (absolute position of first key in window)
881
+ # window_start_pos = seq_len - window_len
882
+ window_kv_lens = prefix_kv_indptr[1 : bs + 1] - prefix_kv_indptr[:bs]
883
+ # Handle TARGET_VERIFY mode where extend_prefix_lens might not be set
884
+ if forward_batch.extend_prefix_lens is not None:
885
+ window_start_pos = (
886
+ forward_batch.extend_prefix_lens[:bs] - window_kv_lens
887
+ )
888
+ else:
889
+ # Infer from spec_info: prefix_len = seq_len - draft_token_num
890
+ if forward_batch.spec_info is not None and hasattr(
891
+ forward_batch.spec_info, "draft_token_num"
892
+ ):
893
+ extend_prefix_lens = (
894
+ forward_batch.seq_lens[:bs]
895
+ - forward_batch.spec_info.draft_token_num
896
+ )
897
+ window_start_pos = extend_prefix_lens - window_kv_lens
898
+ else:
899
+ window_start_pos = None
900
+ else:
901
+ sliding_window_size = -1
902
+ prefix_kv_indptr = self.forward_metadata.kv_indptr
903
+ prefix_kv_indices = self.forward_metadata.kv_indices
904
+ window_start_pos = None
905
+
906
+ # Build unified kv_indices using fused Triton kernel
907
+ extend_kv_indices = forward_batch.out_cache_loc
908
+
909
+ # Handle cases where extend_seq_lens or extend_start_loc might not be set
910
+ # In speculative decoding, we can infer these from spec_info or compute them
911
+ if forward_batch.extend_seq_lens is None:
912
+ # TARGET_VERIFY mode: infer extend_seq_lens from spec_info
913
+ if forward_batch.spec_info is not None and hasattr(
914
+ forward_batch.spec_info, "draft_token_num"
915
+ ):
916
+ draft_token_num = forward_batch.spec_info.draft_token_num
917
+ extend_seq_lens = torch.full(
918
+ (bs,), draft_token_num, dtype=torch.int32, device=self.device
919
+ )
920
+ else:
921
+ raise RuntimeError(
922
+ "extend_seq_lens is None but cannot infer from spec_info. "
923
+ "This should not happen in TARGET_VERIFY mode."
924
+ )
925
+ else:
926
+ extend_seq_lens = forward_batch.extend_seq_lens
927
+
928
+ # Check extend_start_loc separately - it might be None even when extend_seq_lens is set
929
+ if forward_batch.extend_start_loc is None:
930
+ # Compute extend_start_loc from extend_seq_lens
931
+ # extend_start_loc[i] = sum(extend_seq_lens[0:i])
932
+ extend_start_loc = torch.cat(
933
+ [
934
+ torch.zeros(1, dtype=torch.int32, device=self.device),
935
+ torch.cumsum(extend_seq_lens[:-1], dim=0),
936
+ ]
937
+ )
938
+ else:
939
+ extend_start_loc = forward_batch.extend_start_loc
940
+
941
+ unified_kv_indptr, unified_kv_indices, prefix_lens = (
942
+ self.build_unified_kv_indices(
943
+ prefix_kv_indptr,
944
+ prefix_kv_indices,
945
+ extend_start_loc,
946
+ extend_seq_lens,
947
+ extend_kv_indices,
948
+ bs,
949
+ )
950
+ )
951
+
952
+ # Convert prefix_lens to int32 for the kernel
953
+ prefix_lens = prefix_lens.to(torch.int32)
954
+
955
+ # Call unified kernel
956
+ self.extend_attention_fwd_unified(
957
+ q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
958
+ o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
959
+ forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
960
+ forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
961
+ self.forward_metadata.qo_indptr,
962
+ unified_kv_indptr,
963
+ unified_kv_indices,
964
+ prefix_lens,
965
+ self.forward_metadata.max_extend_len,
966
+ custom_mask=self.forward_metadata.custom_mask,
967
+ mask_indptr=self.forward_metadata.mask_indptr,
968
+ sm_scale=layer.scaling,
969
+ logit_cap=logits_soft_cap,
970
+ is_causal=causal,
971
+ sliding_window_size=sliding_window_size,
972
+ sinks=sinks,
973
+ window_start_pos=window_start_pos,
974
+ xai_temperature_len=layer.xai_temperature_len,
975
+ )
976
+
977
+ return o
978
+
788
979
  def forward_decode(
789
980
  self,
790
981
  q: torch.Tensor,
@@ -850,11 +1041,8 @@ class TritonMultiStepDraftBackend:
850
1041
  topk: int,
851
1042
  speculative_num_steps: int,
852
1043
  ):
853
- from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
854
-
855
1044
  self.topk = topk
856
1045
  self.speculative_num_steps = speculative_num_steps
857
- self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
858
1046
  max_bs = model_runner.req_to_token_pool.size * self.topk
859
1047
  self.kv_indptr = torch.zeros(
860
1048
  (
@@ -864,8 +1052,8 @@ class TritonMultiStepDraftBackend:
864
1052
  dtype=torch.int32,
865
1053
  device=model_runner.device,
866
1054
  )
867
- self.attn_backends = []
868
- for i in range(self.speculative_num_steps):
1055
+ self.attn_backends: List[TritonAttnBackend] = []
1056
+ for i in range(self.speculative_num_steps - 1):
869
1057
  self.attn_backends.append(
870
1058
  TritonAttnBackend(
871
1059
  model_runner,
@@ -883,13 +1071,19 @@ class TritonMultiStepDraftBackend:
883
1071
  self.page_size = model_runner.server_args.page_size
884
1072
 
885
1073
  def common_template(
886
- self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
1074
+ self,
1075
+ forward_batch: ForwardBatch,
1076
+ kv_indices_buffer: Optional[torch.Tensor],
1077
+ call_fn: int,
887
1078
  ):
1079
+ if kv_indices_buffer is None:
1080
+ kv_indices_buffer = self.cuda_graph_kv_indices
1081
+
888
1082
  num_seqs = forward_batch.batch_size
889
1083
  bs = self.topk * num_seqs
890
1084
  seq_lens_sum = forward_batch.seq_lens_sum
891
1085
 
892
- self.generate_draft_decode_kv_indices[
1086
+ generate_draft_decode_kv_indices[
893
1087
  (self.speculative_num_steps, num_seqs, self.topk)
894
1088
  ](
895
1089
  forward_batch.req_pool_indices,
@@ -907,7 +1101,10 @@ class TritonMultiStepDraftBackend:
907
1101
  self.page_size,
908
1102
  )
909
1103
 
910
- for i in range(self.speculative_num_steps):
1104
+ if call_fn is None:
1105
+ return
1106
+
1107
+ for i in range(self.speculative_num_steps - 1):
911
1108
  forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
912
1109
  forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
913
1110
  : seq_lens_sum * self.topk + bs * (i + 1)
@@ -941,9 +1138,19 @@ class TritonMultiStepDraftBackend:
941
1138
  dtype=torch.int64,
942
1139
  device=self.device,
943
1140
  )
944
- for i in range(self.speculative_num_steps):
1141
+ self.cuda_graph_num_kv_splits = torch.full(
1142
+ (max_num_tokens,),
1143
+ self.attn_backends[0].max_kv_splits,
1144
+ dtype=torch.int32,
1145
+ device=self.device,
1146
+ )
1147
+
1148
+ for i in range(self.speculative_num_steps - 1):
945
1149
  self.attn_backends[i].init_cuda_graph_state(
946
- max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
1150
+ max_bs,
1151
+ max_num_tokens,
1152
+ kv_indices_buf=self.cuda_graph_kv_indices[i],
1153
+ cuda_graph_num_kv_splits_buf=self.cuda_graph_num_kv_splits,
947
1154
  )
948
1155
 
949
1156
  def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
@@ -958,24 +1165,24 @@ class TritonMultiStepDraftBackend:
958
1165
  spec_info=forward_batch.spec_info,
959
1166
  )
960
1167
 
961
- self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
1168
+ self.common_template(forward_batch, None, call_fn)
962
1169
 
963
1170
  def init_forward_metadata_replay_cuda_graph(
964
1171
  self, forward_batch: ForwardBatch, bs: int
965
1172
  ):
966
- def call_fn(i, forward_batch):
967
- self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
968
- bs,
969
- forward_batch.req_pool_indices,
970
- forward_batch.seq_lens,
971
- seq_lens_sum=-1,
972
- encoder_lens=None,
973
- forward_mode=ForwardMode.DECODE,
974
- spec_info=forward_batch.spec_info,
975
- seq_lens_cpu=None,
976
- )
977
-
978
- self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
1173
+ self.common_template(forward_batch, None, None)
1174
+
1175
+ # NOTE: Multi-step's attention backends use the slice of
1176
+ # - kv_indptr buffer (cuda graph and non-cuda graph)
1177
+ # - kv_indices buffer (cuda graph only)
1178
+ # So we don't need to assign the KV indices inside the attention backend.
1179
+
1180
+ # Compute num_kv_splits only once
1181
+ num_token = forward_batch.batch_size * self.topk
1182
+ self.attn_backends[-1].get_num_kv_splits(
1183
+ self.attn_backends[-1].cuda_graph_num_kv_splits[:num_token],
1184
+ forward_batch.seq_lens[:bs],
1185
+ )
979
1186
 
980
1187
 
981
1188
  @triton.jit
@@ -2,7 +2,7 @@ import torch
2
2
  import triton
3
3
  import triton.language as tl
4
4
 
5
- from sglang.srt.managers.schedule_batch import global_server_args_dict
5
+ from sglang.srt.server_args import get_global_server_args
6
6
  from sglang.srt.utils import is_cuda, is_hip
7
7
 
8
8
  _is_cuda = is_cuda()
@@ -11,7 +11,7 @@ if _is_cuda:
11
11
 
12
12
  _is_hip = is_hip()
13
13
 
14
- if global_server_args_dict.get("attention_reduce_in_fp32", False):
14
+ if get_global_server_args().triton_attention_reduce_in_fp32:
15
15
  REDUCE_TRITON_TYPE = tl.float32
16
16
  REDUCE_TORCH_TYPE = torch.float32
17
17
  else: