sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
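The detailed hunks below come from sglang/srt/layers/attention/nsa_backend.py (entry 133 above, +404 -90). Two themes run through them: the NSA (native sparse attention) backend gains speculative-decoding support (target-verify and draft-extend modes plus a multi-step wrapper), and the FlashMLA implementation names change from "flashmla_prefill"/"flashmla_decode" to "flashmla_sparse"/"flashmla_kv", with the server args renamed from nsa_prefill/nsa_decode to nsa_prefill_backend/nsa_decode_backend. The sketch below illustrates the rename with a hypothetical migration helper; it is not part of SGLang:

# Hypothetical helper illustrating the 0.5.3rc2 -> 0.5.4 NSA backend renames;
# illustrative only, not an SGLang API.
_RENAMED = {"flashmla_prefill": "flashmla_sparse", "flashmla_decode": "flashmla_kv"}
_ALLOWED = {"flashmla_sparse", "flashmla_kv", "fa3", "tilelang"}

def migrate_nsa_backend_name(name: str) -> str:
    """Map an old NSA backend name to its new equivalent, validating the result."""
    new_name = _RENAMED.get(name, name)
    if new_name not in _ALLOWED:
        raise ValueError(f"unknown NSA backend name: {name!r}")
    return new_name

assert migrate_nsa_backend_name("flashmla_decode") == "flashmla_kv"
assert migrate_nsa_backend_name("fa3") == "fa3"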
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import sys
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias
 
@@ -30,22 +29,23 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.speculative.spec_info import SpecInput
 
+
 _is_hip = is_hip()
 
 if _is_hip:
     try:
-        from aiter import (
+        from aiter import (  # noqa: F401
             flash_attn_varlen_func,
             mha_batch_prefill_func,
             paged_attention_ragged,
         )
-        from aiter.mla import mla_decode_fwd, mla_prefill_fwd
+        from aiter.mla import mla_decode_fwd, mla_prefill_fwd  # noqa: F401
     except ImportError:
         print(
             "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
         )
 else:
-    from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+    from sgl_kernel.flash_attn import flash_attn_with_kvcache
 
 
 @dataclass(frozen=True)
@@ -140,16 +140,21 @@ def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
     )
 
 
-_NSA_IMPL_T: TypeAlias = Literal[
-    "flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
-]
+_NSA_IMPL_T: TypeAlias = Literal["flashmla_sparse", "flashmla_kv", "fa3", "tilelang"]
 
 NSA_PREFILL_IMPL: _NSA_IMPL_T
 NSA_DECODE_IMPL: _NSA_IMPL_T
 
 
 class NativeSparseAttnBackend(AttentionBackend):
-    def __init__(self, model_runner: ModelRunner):
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        speculative_step_id=0,
+        topk=0,
+        speculative_num_steps=0,
+    ):
         super().__init__()
         self.forward_metadata: NSAMetadata
         self.device = model_runner.device
@@ -174,8 +179,8 @@ class NativeSparseAttnBackend(AttentionBackend):
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
 
         global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
-        NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
-        NSA_DECODE_IMPL = model_runner.server_args.nsa_decode
+        NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill_backend
+        NSA_DECODE_IMPL = model_runner.server_args.nsa_decode_backend
 
         self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
 
@@ -186,6 +191,14 @@ class NativeSparseAttnBackend(AttentionBackend):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
 
+        # Speculative decoding
+        self.topk = model_runner.server_args.speculative_eagle_topk or 0
+        self.speculative_num_steps = speculative_num_steps
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+        self.speculative_step_id = speculative_step_id
+
     def get_device_int32_arange(self, l: int) -> torch.Tensor:
         if l > len(self._arange_buf):
             next_pow_of_2 = 1 << (l - 1).bit_length()
@@ -209,13 +222,15 @@ class NativeSparseAttnBackend(AttentionBackend):
         batch_size = forward_batch.batch_size
        device = forward_batch.seq_lens.device
 
-        assert (
-            forward_batch.spec_info is None
-        ), "Spec decoding is not supported for NSA backend now"
-        cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
+        if forward_batch.forward_mode.is_target_verify():
+            draft_token_num = self.speculative_num_draft_tokens
+        else:
+            draft_token_num = 0
+
+        cache_seqlens_int32 = (forward_batch.seq_lens + draft_token_num).to(torch.int32)
         cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
         assert forward_batch.seq_lens_cpu is not None
-        max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
+        max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item() + draft_token_num)
         page_table = forward_batch.req_to_token_pool.req_to_token[
             forward_batch.req_pool_indices, :max_seqlen_k
         ]
@@ -225,6 +240,41 @@ class NativeSparseAttnBackend(AttentionBackend):
             max_seqlen_q = 1
             cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
             seqlens_expanded = cache_seqlens_int32
+        elif forward_batch.forward_mode.is_target_verify():
+            max_seqlen_q = self.speculative_num_draft_tokens
+            nsa_max_seqlen_q = self.speculative_num_draft_tokens
+            cu_seqlens_q = torch.arange(
+                0,
+                batch_size * self.speculative_num_draft_tokens + 1,
+                1,
+                dtype=torch.int32,
+                device=device,
+            )
+            extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * batch_size
+            forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu
+
+            seqlens_int32_cpu = [
+                self.speculative_num_draft_tokens + kv_len
+                for kv_len in forward_batch.seq_lens_cpu.tolist()
+            ]
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=device,
+                    )
+                    for qo_len, kv_len in zip(
+                        extend_seq_lens_cpu,
+                        seqlens_int32_cpu,
+                        strict=True,
+                    )
+                ]
+            )
+            page_table = torch.repeat_interleave(
+                page_table, repeats=self.speculative_num_draft_tokens, dim=0
+            )
         elif forward_batch.forward_mode.is_extend():
             assert (
                 forward_batch.extend_seq_lens_cpu is not None
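The target-verify branch above flattens each request into speculative_num_draft_tokens single-token query rows and builds one KV length per row, so row i of a request attends to a causally growing prefix ending at seq_len + i. A standalone sketch of that expansion (plain torch; the function name is illustrative, not from the diff):

import torch

def expand_seqlens_for_verify(seq_lens: list[int], num_draft_tokens: int) -> torch.Tensor:
    # With kv_len cached tokens and num_draft_tokens draft rows, draft row i
    # (1-based) may attend to kv_len + i tokens, i.e. the per-row KV lengths
    # for one request are arange(kv_len + 1, kv_len + num_draft_tokens + 1).
    return torch.cat(
        [
            torch.arange(kv_len + 1, kv_len + num_draft_tokens + 1, dtype=torch.int32)
            for kv_len in seq_lens
        ]
    )

# Two requests with 5 and 7 cached tokens, 3 draft tokens each:
print(expand_seqlens_for_verify([5, 7], 3))
# tensor([ 6,  7,  8,  8,  9, 10], dtype=torch.int32)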
@@ -233,7 +283,11 @@ class NativeSparseAttnBackend(AttentionBackend):
             ), "All of them must not be None"
             extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
             assert forward_batch.extend_seq_lens is not None
-            if any(forward_batch.extend_prefix_lens_cpu):
+
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
                 max_seqlen_q = max(extend_seq_lens_cpu)
                 cu_seqlens_q = compute_cu_seqlens(
                     forward_batch.extend_seq_lens.to(torch.int32)
@@ -278,9 +332,9 @@ class NativeSparseAttnBackend(AttentionBackend):
             flashmla_metadata=(
                 self._compute_flashmla_metadata(
                     cache_seqlens=nsa_cache_seqlens_int32,
-                    seq_len_q=1,  # TODO handle MTP which is not 1
+                    seq_len_q=1,
                 )
-                if NSA_DECODE_IMPL == "flashmla_decode"
+                if NSA_DECODE_IMPL == "flashmla_kv"
                 else None
             ),
             nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
@@ -289,6 +343,7 @@ class NativeSparseAttnBackend(AttentionBackend):
             nsa_seqlens_expanded=seqlens_expanded,
             nsa_extend_seq_lens_list=extend_seq_lens_cpu,
             real_page_table=self._transform_table_1_to_real(page_table),
+            nsa_max_seqlen_q=1,
         )
 
         self.forward_metadata = metadata
@@ -303,7 +358,9 @@ class NativeSparseAttnBackend(AttentionBackend):
         to avoid memory allocations.
         """
         self.decode_cuda_graph_metadata: Dict = {
-            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cache_seqlens": torch.ones(
+                max_num_tokens, dtype=torch.int32, device=self.device
+            ),
             "cu_seqlens_q": torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=self.device
             ),
@@ -312,7 +369,7 @@ class NativeSparseAttnBackend(AttentionBackend):
             ),
             # fake page_table for sparse_prefill
             "page_table": torch.zeros(
-                max_bs,
+                max_num_tokens,
                 self.max_context_len,
                 dtype=torch.int32,
                 device=self.device,
@@ -320,11 +377,11 @@ class NativeSparseAttnBackend(AttentionBackend):
             "flashmla_metadata": (
                 self._compute_flashmla_metadata(
                     cache_seqlens=torch.ones(
-                        max_bs, dtype=torch.int32, device=self.device
+                        max_num_tokens, dtype=torch.int32, device=self.device
                     ),
-                    seq_len_q=1,  # TODO handle MTP which is not 1
+                    seq_len_q=1,
                 )
-                if NSA_DECODE_IMPL == "flashmla_decode"
+                if NSA_DECODE_IMPL == "flashmla_kv"
                 else None
             ),
         }
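The capture-time buffers above ("cache_seqlens", the fake "page_table", and the FlashMLA metadata) are now allocated with max_num_tokens rows instead of max_bs: under target-verify every request contributes speculative_num_draft_tokens flattened query rows, so the row count can exceed the batch size. A toy statement of the sizing rule as read from these hunks (not SGLang code):

def required_capture_rows(bs: int, num_draft_tokens: int) -> int:
    """Rows needed in per-row capture buffers: one per flattened query row."""
    return bs * max(num_draft_tokens, 1)

assert required_capture_rows(8, 1) == 8    # plain decode: one row per request
assert required_capture_rows(8, 4) == 32   # verify: needs token-sized buffers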
@@ -340,50 +397,166 @@ class NativeSparseAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInput],
     ):
         """Initialize forward metadata for capturing CUDA graph."""
-        assert forward_mode.is_decode_or_idle(), "Only support decode for now"
-        assert (
-            spec_info is None
-        ), "Speculative decoding is not supported for NSA backend now"
+        if forward_mode.is_decode_or_idle():
+            # Normal Decode
+            # Get sequence information
+            cache_seqlens_int32 = seq_lens.to(torch.int32)
+            cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+
+            # Use max context length for seq_len_k
+            page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
+            max_seqlen_q = 1
+            max_seqlen_k = page_table_1.shape[1]
 
-        # Normal Decode
-        # Get sequence information
-        cache_seqlens_int32 = seq_lens.to(torch.int32)
-        cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+            # Precompute page table
+            # Precompute cumulative sequence lengths
 
-        # Use max context length for seq_len_k
-        page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
-        max_seq_len_k = page_table_1.shape[1]
+            # NOTE(dark): this is always arange, since we are decoding
+            cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
+            nsa_cache_seqlens_int32 = compute_nsa_seqlens(
+                cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
+            )
 
-        # Precompute page table
-        # Precompute cumulative sequence lengths
+            seqlens_expanded = cache_seqlens_int32
+            nsa_extend_seq_lens_list = [1] * num_tokens
+            if NSA_DECODE_IMPL == "flashmla_kv":
+                flashmla_metadata = self.decode_cuda_graph_metadata[
+                    "flashmla_metadata"
+                ].slice(slice(0, num_tokens + 1))
+                flashmla_metadata.copy_(
+                    self._compute_flashmla_metadata(
+                        cache_seqlens=nsa_cache_seqlens_int32,
+                        seq_len_q=1,
+                    )
+                )
+            else:
+                flashmla_metadata = None
+        elif forward_mode.is_target_verify():
+            cache_seqlens_int32 = (seq_lens + self.speculative_num_draft_tokens).to(
+                torch.int32
+            )
+            cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+            max_seqlen_q = 1
+            page_table_1 = self.decode_cuda_graph_metadata["page_table"][
+                : bs * self.speculative_num_draft_tokens, :
+            ]
+            max_seqlen_k = page_table_1.shape[1]
+
+            cu_seqlens_q = torch.arange(
+                0,
+                bs * self.speculative_num_draft_tokens + 1,
+                1,
+                dtype=torch.int32,
+                device=self.device,
+            )
 
-        # NOTE(dark): this is always arange, since we are decoding
-        cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
-        nsa_cache_seqlens_int32 = compute_nsa_seqlens(
-            cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
-        )
-        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
-        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
-        real_page_table = self._transform_table_1_to_real(page_table_1)
+            extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
 
-        if NSA_DECODE_IMPL == "flashmla_decode":
-            flashmla_metadata = self.decode_cuda_graph_metadata[
-                "flashmla_metadata"
-            ].slice(slice(0, bs + 1))
-            flashmla_metadata.copy_(
-                self._compute_flashmla_metadata(
-                    cache_seqlens=nsa_cache_seqlens_int32,
-                    seq_len_q=1,  # TODO handle MTP which is not 1
+            seqlens_int32_cpu = [
+                self.speculative_num_draft_tokens + kv_len
+                for kv_len in seq_lens.tolist()
+            ]
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    )
+                    for qo_len, kv_len in zip(
+                        extend_seq_lens_cpu,
+                        seqlens_int32_cpu,
+                        strict=True,
+                    )
+                ]
+            )
+            nsa_cache_seqlens_int32 = compute_nsa_seqlens(
+                seqlens_expanded, nsa_index_topk=self.nsa_index_topk
+            )
+            nsa_extend_seq_lens_list = [1] * bs * self.speculative_num_draft_tokens
+
+            if NSA_DECODE_IMPL == "flashmla_kv":
+                flashmla_metadata = self.decode_cuda_graph_metadata[
+                    "flashmla_metadata"
+                ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
+
+                flashmla_metadata.copy_(
+                    self._compute_flashmla_metadata(
+                        cache_seqlens=nsa_cache_seqlens_int32,
+                        seq_len_q=1,
+                    )
                 )
+            else:
+                flashmla_metadata = None
+        elif forward_mode.is_draft_extend():
+            cache_seqlens_int32 = (seq_lens + self.speculative_num_draft_tokens).to(
+                torch.int32
             )
-        else:
-            flashmla_metadata = None
+            cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+            page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
+            max_seqlen_k = page_table_1.shape[1]
+
+            extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
+            extend_seq_lens = torch.full(
+                (bs,),
+                self.speculative_num_draft_tokens,
+                device=self.device,
+                dtype=torch.int32,
+            )
+
+            max_seqlen_q = max(extend_seq_lens_cpu)
+            cu_seqlens_q = compute_cu_seqlens(extend_seq_lens.to(torch.int32))
+
+            seqlens_int32_cpu = [
+                self.speculative_num_draft_tokens + kv_len
+                for kv_len in seq_lens.tolist()
+            ]
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    )
+                    for qo_len, kv_len in zip(
+                        extend_seq_lens_cpu,
+                        seqlens_int32_cpu,
+                        strict=True,
+                    )
+                ]
+            )
+            nsa_cache_seqlens_int32 = compute_nsa_seqlens(
+                seqlens_expanded, nsa_index_topk=self.nsa_index_topk
+            )
+            nsa_extend_seq_lens_list = [1] * bs
+
+            if NSA_DECODE_IMPL == "flashmla_kv":
+                flashmla_metadata = self.decode_cuda_graph_metadata[
+                    "flashmla_metadata"
+                ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
+                # As the DeepGemm is not support for q_len = 3/4 in Indexer and every token has independent topk_indices,
+                # we made the Q shape [bs * speculative_num_draft_tokens, 1, head_nums, dim].
+                # So seq_len_q is 1 for flashmla_metadata in target_verify and draft_extend mode.
+                flashmla_metadata.copy_(
+                    self._compute_flashmla_metadata(
+                        cache_seqlens=nsa_cache_seqlens_int32,
+                        seq_len_q=1,
+                    )
+                )
+            else:
+                flashmla_metadata = None
+
+        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
+        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
+        real_page_table = self._transform_table_1_to_real(page_table_1)
 
         metadata = NSAMetadata(
             page_size=self.real_page_size,
             cache_seqlens_int32=cache_seqlens_int32,
-            max_seq_len_q=1,
-            max_seq_len_k=max_seq_len_k,
+            max_seq_len_q=max_seqlen_q,
+            max_seq_len_k=max_seqlen_k,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_k=cu_seqlens_k,
             page_table_1=page_table_1,
@@ -391,9 +564,9 @@ class NativeSparseAttnBackend(AttentionBackend):
             nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
             nsa_cu_seqlens_q=nsa_cu_seqlens_q,
             nsa_cu_seqlens_k=nsa_cu_seqlens_k,
-            nsa_seqlens_expanded=cache_seqlens_int32,
+            nsa_seqlens_expanded=seqlens_expanded,
             real_page_table=real_page_table,
-            nsa_extend_seq_lens_list=[1] * bs,
+            nsa_extend_seq_lens_list=nsa_extend_seq_lens_list,
         )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_metadata = metadata
@@ -412,33 +585,119 @@ class NativeSparseAttnBackend(AttentionBackend):
     ):
         """Initialize forward metadata for replaying CUDA graph."""
         assert seq_lens_cpu is not None
-        assert forward_mode.is_decode_or_idle(), "Only support decode for now"
-        assert (
-            spec_info is None
-        ), "Speculative decoding is not supported for NSA backend now"
+
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
 
         # Normal Decode
         metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
-        max_len = int(seq_lens_cpu.max().item())
+        if forward_mode.is_decode_or_idle():
+            # Normal Decode
+            max_len = int(seq_lens_cpu.max().item())
+
+            cache_seqlens = seq_lens.to(torch.int32)
+            metadata.cache_seqlens_int32.copy_(cache_seqlens)
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
+            )
+            page_indices = self.req_to_token[req_pool_indices, :max_len]
+            metadata.page_table_1[:, :max_len].copy_(page_indices)
+            nsa_cache_seqlens = compute_nsa_seqlens(
+                cache_seqlens, nsa_index_topk=self.nsa_index_topk
+            )
+            metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
+            seqlens_expanded = cache_seqlens
+        elif forward_mode.is_target_verify():
+            max_seqlen_k = int(
+                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+            )
 
-        cache_seqlens = seq_lens.to(torch.int32)
-        metadata.cache_seqlens_int32.copy_(cache_seqlens)
-        metadata.cu_seqlens_k[1:].copy_(
-            torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
-        )
-        page_indices = self.req_to_token[req_pool_indices, :max_len]
-        metadata.page_table_1[:, :max_len].copy_(page_indices)
+            cache_seqlens = (seq_lens + self.speculative_num_draft_tokens).to(
+                torch.int32
+            )
+            metadata.cache_seqlens_int32.copy_(cache_seqlens)
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
+            )
+            page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]
+            page_indices = torch.repeat_interleave(
+                page_indices, repeats=self.speculative_num_draft_tokens, dim=0
+            )
+            metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices)
+            extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
+
+            seqlens_int32_cpu = [
+                self.speculative_num_draft_tokens + kv_len
+                for kv_len in seq_lens_cpu.tolist()
+            ]
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    )
+                    for qo_len, kv_len in zip(
+                        extend_seq_lens_cpu,
+                        seqlens_int32_cpu,
+                        strict=True,
+                    )
+                ]
+            )
+            metadata.nsa_seqlens_expanded.copy_(seqlens_expanded)
+            nsa_cache_seqlens = compute_nsa_seqlens(
+                seqlens_expanded, self.nsa_index_topk
+            )
+            metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
+        elif forward_mode.is_draft_extend():
+            max_seqlen_k = int(seq_lens_cpu.max().item())
+            cache_seqlens = seq_lens.to(torch.int32)
+            metadata.cache_seqlens_int32.copy_(cache_seqlens)
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
+            )
+            page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]
+            metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices)
+            extend_seq_lens_cpu = spec_info.accept_length[:bs].tolist()
+
+            seqlens_int32_cpu = [
+                self.speculative_num_draft_tokens + kv_len
+                for kv_len in seq_lens_cpu.tolist()
+            ]
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    )
+                    for qo_len, kv_len in zip(
+                        extend_seq_lens_cpu,
+                        seqlens_int32_cpu,
+                        strict=True,
+                    )
+                ]
+            )
+            metadata.nsa_seqlens_expanded[: seqlens_expanded.size(0)].copy_(
+                seqlens_expanded
+            )
+            nsa_cache_seqlens = compute_nsa_seqlens(
+                seqlens_expanded, self.nsa_index_topk
+            )
+            metadata.nsa_cache_seqlens_int32[: seqlens_expanded.size(0)].copy_(
+                nsa_cache_seqlens
+            )
+        seqlens_expanded_size = seqlens_expanded.size(0)
         assert (
             metadata.nsa_cache_seqlens_int32 is not None
             and metadata.nsa_cu_seqlens_k is not None
            and self.nsa_index_topk is not None
        )
-        nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk)
-        metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
-        metadata.nsa_cu_seqlens_k[1:].copy_(
+
+        metadata.nsa_cu_seqlens_k[1 : 1 + seqlens_expanded_size].copy_(
             torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
        )
        # NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy
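All three replay paths above follow the same CUDA-graph-friendly pattern: metadata tensors are preallocated at capture time for the worst case, and replay overwrites only the leading seqlens_expanded_size entries in place (copy_ into a slice), so the tensor addresses captured in the graph never change. A minimal standalone sketch of that pattern:

import torch

MAX_ROWS = 16  # worst-case size, fixed at capture time
buf = torch.zeros(MAX_ROWS, dtype=torch.int32)

def replay_update(buf: torch.Tensor, new_vals: torch.Tensor) -> None:
    # Overwrite only the first len(new_vals) entries in place; the tail keeps
    # stale values that the kernels never read for the current batch.
    buf[: new_vals.size(0)].copy_(new_vals)

replay_update(buf, torch.tensor([6, 7, 8], dtype=torch.int32))
print(buf[:5])  # tensor([6, 7, 8, 0, 0], dtype=torch.int32)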
@@ -451,11 +710,14 @@ class NativeSparseAttnBackend(AttentionBackend):
         else:
             assert metadata.real_page_table is metadata.page_table_1
 
-        if NSA_DECODE_IMPL == "flashmla_decode":
-            metadata.flashmla_metadata.copy_(
+        if NSA_DECODE_IMPL == "flashmla_kv":
+            flashmla_metadata = metadata.flashmla_metadata.slice(
+                slice(0, seqlens_expanded_size + 1)
+            )
+            flashmla_metadata.copy_(
                 self._compute_flashmla_metadata(
                     cache_seqlens=nsa_cache_seqlens,
-                    seq_len_q=1,  # TODO handle MTP which is not 1
+                    seq_len_q=1,
                 )
             )
 
@@ -474,10 +736,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         topk_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        assert (
-            not forward_batch.forward_mode.is_target_verify()
-            and not forward_batch.forward_mode.is_draft_extend()
-        ), "NSA backend doesn't support speculative decoding"
+
         if k is not None:
             assert v is not None
             if save_kv_cache:
@@ -542,20 +801,20 @@ class NativeSparseAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
             )
-        elif NSA_PREFILL_IMPL == "flashmla_prefill":
+        elif NSA_PREFILL_IMPL == "flashmla_sparse":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_flashmla_prefill(
+            return self._forward_flashmla_sparse(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 page_table_1=page_table_1,
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
             )
-        elif NSA_PREFILL_IMPL == "flashmla_decode":
+        elif NSA_PREFILL_IMPL == "flashmla_kv":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_flashmla_decode(
+            return self._forward_flashmla_kv(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 sm_scale=layer.scaling,
@@ -636,20 +895,20 @@ class NativeSparseAttnBackend(AttentionBackend):
             page_size=1,
         )
 
-        if NSA_DECODE_IMPL == "flashmla_prefill":
+        if NSA_DECODE_IMPL == "flashmla_sparse":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_flashmla_prefill(
+            return self._forward_flashmla_sparse(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 page_table_1=page_table_1,
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
             )
-        elif NSA_DECODE_IMPL == "flashmla_decode":
+        elif NSA_DECODE_IMPL == "flashmla_kv":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_flashmla_decode(
+            return self._forward_flashmla_kv(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 sm_scale=layer.scaling,
@@ -737,7 +996,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         )
         return o  # type: ignore
 
-    def _forward_flashmla_prefill(
+    def _forward_flashmla_sparse(
         self,
         q_all: torch.Tensor,
         kv_cache: torch.Tensor,
@@ -756,7 +1015,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         )
         return o
 
-    def _forward_flashmla_decode(
+    def _forward_flashmla_kv(
         self,
         q_all: torch.Tensor,
         kv_cache: torch.Tensor,
@@ -885,3 +1144,58 @@ class NativeSparseAttnBackend(AttentionBackend):
             flashmla_metadata=flashmla_metadata,
             num_splits=num_splits,
         )
+
+
+class NativeSparseAttnMultiStepBackend:
+
+    def __init__(
+        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+    ):
+        self.model_runner = model_runner
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                NativeSparseAttnBackend(
+                    model_runner,
+                    speculative_step_id=i,
+                    topk=self.topk,
+                    speculative_num_steps=self.speculative_num_steps,
+                )
+            )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
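The new NativeSparseAttnMultiStepBackend wraps one NativeSparseAttnBackend per speculative draft step and fans each metadata call out to all of them; note that init_forward_metadata loops over range(speculative_num_steps - 1), skipping the last step. A self-contained sketch of that fan-out pattern with stub classes (illustrative names, not SGLang code):

class StubStepBackend:
    """Stand-in for one per-step attention backend."""

    def __init__(self, step_id: int) -> None:
        self.step_id = step_id

    def init_forward_metadata(self, batch: dict) -> None:
        print(f"step {self.step_id}: metadata for bs={batch['bs']}")

class MultiStepDriver:
    """Fan a metadata call out to one backend per draft step."""

    def __init__(self, num_steps: int) -> None:
        self.backends = [StubStepBackend(i) for i in range(num_steps)]

    def init_forward_metadata(self, batch: dict) -> None:
        for backend in self.backends[:-1]:  # last step skipped, as in the diff
            backend.init_forward_metadata(batch)

MultiStepDriver(num_steps=3).init_forward_metadata({"bs": 4})  # runs steps 0 and 1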