sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -9,28 +9,21 @@ and uses BatchMLAPaged wrapper for decoding.
 More details can be found in https://docs.flashinfer.ai/api/mla.html
 """
 
-import os
 from dataclasses import dataclass
 from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
 
-if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import logging
-
-    torch._logging.set_logs(dynamo=logging.ERROR)
-    torch._dynamo.config.suppress_errors = True
-
-from sglang.global_config import global_config
+from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.server_args import get_global_server_args
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     is_flashinfer_available,
     is_sm100_supported,
@@ -38,9 +31,18 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.attention.flashinfer_mla_backend import (
+        FlashInferMlaAttnBackend,
+    )
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput
+
+if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
+    import logging
+
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
 
 if is_flashinfer_available():
     from flashinfer import (
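
The two hunks above drop the import-time os.environ["SGLANG_ENABLE_TORCH_COMPILE"] lookup in favor of the typed accessor from the new sglang.srt.environ module (envs.SGLANG_ENABLE_TORCH_COMPILE.get()), and move the torch.compile logging tweaks below the TYPE_CHECKING block. A minimal self-contained sketch of the accessor pattern, using a hypothetical EnvBool helper rather than the real envs object:

    import os

    class EnvBool:
        """Typed view over one environment flag (stand-in for the real envs.* entries)."""

        def __init__(self, name: str, default: bool = False):
            self.name = name
            self.default = default

        def get(self) -> bool:
            raw = os.environ.get(self.name)
            # The old style, os.environ["..."], raises KeyError when the variable
            # is unset; the accessor falls back to a default instead.
            if raw is None:
                return self.default
            return raw.strip().lower() in ("1", "true", "yes")

    SGLANG_ENABLE_TORCH_COMPILE = EnvBool("SGLANG_ENABLE_TORCH_COMPILE")

    if SGLANG_ENABLE_TORCH_COMPILE.get():
        print("torch.compile tweaks would be applied here")
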
@@ -66,7 +68,7 @@ global_workspace_buffer = None
 
 class FlashInferMhaChunkKVRunner:
     def __init__(
-        self, model_runner: ModelRunner, attn_backend: "FlashInferMlaAttnBackend"
+        self, model_runner: ModelRunner, attn_backend: FlashInferMlaAttnBackend
     ):
         # Parse Constants
         self.num_local_heads = (
@@ -193,9 +195,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.skip_prefill = skip_prefill
         self.enable_chunk_kv = (
             not skip_prefill
-            and global_server_args_dict["disaggregation_mode"] != "decode"
-            and not global_server_args_dict["disable_chunked_prefix_cache"]
-            and not global_server_args_dict["flashinfer_mla_disable_ragged"]
+            and get_global_server_args().disaggregation_mode != "decode"
+            and not get_global_server_args().disable_chunked_prefix_cache
+            and not get_global_server_args().flashinfer_mla_disable_ragged
         )
         self.page_size = model_runner.page_size
 
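
The hunk above swaps dictionary lookups on global_server_args_dict for attribute access on the object returned by get_global_server_args(), a pattern repeated throughout this file. A rough standalone sketch of the difference, with a hypothetical dataclass standing in for the real server-args object (field names are taken from the hunk, defaults are illustrative only):

    from dataclasses import dataclass

    @dataclass
    class ServerArgs:
        # Hypothetical stand-in; the real object carries many more fields.
        disaggregation_mode: str = "null"
        disable_chunked_prefix_cache: bool = False
        flashinfer_mla_disable_ragged: bool = False

    _GLOBAL_SERVER_ARGS = ServerArgs()

    def get_global_server_args() -> ServerArgs:
        return _GLOBAL_SERVER_ARGS

    # Old style: global_server_args_dict["disaggregation_mode"] != "decode"
    # New style: typed attribute access that editors and linters can check.
    enable_chunk_kv = (
        get_global_server_args().disaggregation_mode != "decode"
        and not get_global_server_args().disable_chunked_prefix_cache
        and not get_global_server_args().flashinfer_mla_disable_ragged
    )
    print(enable_chunk_kv)
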
@@ -204,7 +206,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
+                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
@@ -306,7 +308,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             prefix_lens = forward_batch.extend_prefix_lens
             extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
             use_ragged = (
-                not global_server_args_dict["flashinfer_mla_disable_ragged"]
+                not get_global_server_args().flashinfer_mla_disable_ragged
                 and extend_no_prefix
             )
 
@@ -361,7 +363,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
@@ -441,7 +443,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -663,7 +665,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrapper: BatchMLAPagedAttentionWrapper,
         init_metadata_replay: bool = False,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+        spec_info: Optional[SpecInput] = None,
         **fast_decode_kwargs,
     ):
         decode_wrapper = decode_wrapper or self.decode_wrapper
@@ -688,7 +690,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         q_indptr: torch.Tensor,
         kv_indptr: torch.Tensor,
         init_metadata_replay: bool = False,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+        spec_info: Optional[SpecInput] = None,
         **fast_decode_kwargs,
     ):
         bs = len(req_pool_indices)
@@ -776,7 +778,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
         use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+        spec_info: Optional[SpecInput] = None,
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -811,7 +813,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+        spec_info: Optional[SpecInput] = None,
     ):
         bs = len(seq_lens)
         sm_scale = self.scaling
@@ -838,9 +840,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
-            assert isinstance(spec_info, EagleDraftInput) or isinstance(
-                spec_info, EagleVerifyInput
-            )
+            assert isinstance(spec_info, SpecInput)
             # TODO: Support topk > 1 with custom mask
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
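
Across the hunks above, every Optional[Union[EagleDraftInput, EagleVerifyInput]] annotation collapses to Optional[SpecInput], and the paired isinstance checks become a single check on the shared type (or a call such as spec_info.is_draft_input()). A self-contained sketch of that shape, assuming SpecInput acts as a common base class for draft and verify inputs; the real spec_info module may be organized differently:

    from typing import Optional

    class SpecInput:
        # Hypothetical base type mirroring the new annotation.
        def is_draft_input(self) -> bool:
            return isinstance(self, EagleDraftInput)

    class EagleDraftInput(SpecInput):
        pass

    class EagleVerifyInput(SpecInput):
        pass

    def describe(spec_info: Optional[SpecInput] = None) -> str:
        if spec_info is None:
            return "no speculative input"
        # One isinstance check replaces the old two-way Union check.
        assert isinstance(spec_info, SpecInput)
        return "draft" if spec_info.is_draft_input() else "verify"

    print(describe(EagleDraftInput()))   # draft
    print(describe(EagleVerifyInput()))  # verify
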
@@ -894,7 +894,7 @@ class FlashInferMLAMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
 
         if topk > 1:
             raise ValueError(
@@ -918,7 +918,7 @@
         )
 
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferMLAAttnBackend(
                     model_runner,
@@ -963,7 +963,7 @@
         )
 
         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()
 
         for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
@@ -983,8 +983,6 @@
         )
 
         def call_fn(i, forward_batch):
-            assert forward_batch.spec_info is not None
-            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )
@@ -1002,7 +1000,7 @@
             device="cuda",
         )
 
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
@@ -1064,7 +1062,7 @@ def fast_mla_decode_plan(
 
     try:
         # Standard version with just the required arguments (no use_profiler)
-        self._cached_module.plan.default(
+        self._cached_module.plan(
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
@@ -19,7 +19,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput
 
 
 # FlashMLA only supports pagesize=64
@@ -187,7 +187,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
@@ -201,9 +201,10 @@
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
+            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-                self.num_q_heads,
+                num_q_heads,
                 1,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
@@ -257,7 +258,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
 
@@ -275,9 +276,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
+            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-                self.num_q_heads,
+                num_q_heads,
                 1,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
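
Both cuda-graph paths above now scale the query-head count handed to get_mla_metadata by the number of speculative draft tokens, using (self.num_draft_tokens or 1) so the non-speculative case is unchanged. A small worked example of the idiom (head counts are illustrative, not taken from any model config):

    from typing import Optional

    def effective_q_heads(num_q_heads: int, num_draft_tokens: Optional[int]) -> int:
        # `x or 1` evaluates to 1 when num_draft_tokens is None or 0,
        # so plain decoding keeps the original head count.
        return num_q_heads * (num_draft_tokens or 1)

    print(effective_q_heads(16, None))  # 16 -> speculative decoding disabled
    print(effective_q_heads(16, 4))     # 64 -> 4 draft tokens per step
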
@@ -476,7 +478,7 @@ class FlashMLAMultiStepDraftBackend:
         )
 
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashMLABackend(
                     model_runner,
@@ -504,7 +506,7 @@
         self.common_template(forward_batch, call_fn)
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, block_kv_indices=None
             )
@@ -1,12 +1,13 @@
-from typing import Optional, Union
+from typing import Optional
 
 import torch
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 
 
 class HybridAttnBackend(AttentionBackend):
@@ -21,6 +22,7 @@ class HybridAttnBackend(AttentionBackend):
         self.model_runner = model_runner
         self.prefill_backend = prefill_backend
         self.decode_backend = decode_backend
+        self.data_type = model_runner.kv_cache_dtype
 
     def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
         """
@@ -70,7 +72,7 @@
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         backend = self._select_backend(forward_mode)
         backend.init_forward_metadata_capture_cuda_graph(
@@ -91,7 +93,7 @@
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         backend = self._select_backend(forward_mode)
@@ -137,3 +139,9 @@
         return backend.forward_extend(
             q, k, v, layer, forward_batch, save_kv_cache, **kwargs
         )
+
+    def get_indexer_metadata(
+        self, layer_id: int, forward_batch: ForwardBatch
+    ) -> Optional[BaseIndexerMetadata]:
+        backend = self._select_backend(forward_batch.forward_mode)
+        return backend.get_indexer_metadata(layer_id, forward_batch)
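
The new get_indexer_metadata method follows the same delegation pattern the rest of HybridAttnBackend already uses: pick the prefill or decode backend from the forward mode, then forward the call unchanged. A stripped-down sketch of that dispatch, with toy classes in place of the real attention backends and forward batch:

    class ToyBackend:
        # Minimal stand-in for an attention backend that can serve indexer metadata.
        def __init__(self, name: str):
            self.name = name

        def get_indexer_metadata(self, layer_id: int, forward_batch: dict) -> str:
            return f"{self.name} metadata for layer {layer_id}"

    class HybridDispatch:
        def __init__(self, prefill_backend: ToyBackend, decode_backend: ToyBackend):
            self.prefill_backend = prefill_backend
            self.decode_backend = decode_backend

        def _select_backend(self, forward_mode: str) -> ToyBackend:
            # Decode batches go to the decode backend, everything else to prefill.
            if forward_mode == "decode":
                return self.decode_backend
            return self.prefill_backend

        def get_indexer_metadata(self, layer_id: int, forward_batch: dict) -> str:
            backend = self._select_backend(forward_batch["forward_mode"])
            return backend.get_indexer_metadata(layer_id, forward_batch)

    hybrid = HybridDispatch(ToyBackend("prefill"), ToyBackend("decode"))
    print(hybrid.get_indexer_metadata(0, {"forward_mode": "decode"}))
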