sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1201 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias
5
+
6
+ import torch
7
+
8
+ from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
9
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
10
+ from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
11
+ from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
12
+ from sglang.srt.layers.attention.nsa.transform_index import (
13
+ transform_index_page_table_decode,
14
+ transform_index_page_table_prefill,
15
+ )
16
+ from sglang.srt.layers.attention.nsa.utils import (
17
+ NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
18
+ NSA_FUSE_TOPK,
19
+ compute_nsa_seqlens,
20
+ )
21
+ from sglang.srt.layers.dp_attention import get_attention_tp_size
22
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
23
+ from sglang.srt.utils import is_hip
24
+
25
+ # from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
26
+
27
+ if TYPE_CHECKING:
28
+ from sglang.srt.layers.radix_attention import RadixAttention
29
+ from sglang.srt.model_executor.model_runner import ModelRunner
30
+ from sglang.srt.speculative.spec_info import SpecInput
31
+
32
+
33
# Pick the kernel provider for this platform once, at import time.
_is_hip = is_hip()

if not _is_hip:
    # Non-HIP platforms: FA3 kernel used by `NativeSparseAttnBackend._forward_fa3`.
    # It is only imported here, so the "fa3" implementation is unavailable on HIP.
    from sgl_kernel.flash_attn import flash_attn_with_kvcache
else:
    # HIP/ROCm: kernels come from aiter. A missing install is reported but not
    # fatal at import time.
    try:
        from aiter import (  # noqa: F401
            flash_attn_varlen_func,
            mha_batch_prefill_func,
            paged_attention_ragged,
        )
        from aiter.mla import mla_decode_fwd, mla_prefill_fwd  # noqa: F401
    except ImportError:
        print(
            "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
        )
49
+
50
+
51
@dataclass(frozen=True)
class NSAFlashMLAMetadata:
    """Scheduling metadata consumed only by the FlashMLA KV kernel.

    The dataclass is frozen, so updating it means writing into its existing
    tensors in place (see ``copy_``).
    """

    flashmla_metadata: torch.Tensor
    num_splits: torch.Tensor

    def slice(self, sli):
        """Return a new instance whose ``num_splits`` is ``num_splits[sli]``.

        ``flashmla_metadata`` is passed through unsliced (the same tensor
        object). With a basic slice, the new ``num_splits`` is a view of this
        instance's tensor, so ``copy_`` on the result writes through.
        """
        trimmed_splits = self.num_splits[sli]
        return NSAFlashMLAMetadata(self.flashmla_metadata, trimmed_splits)

    def copy_(self, other: "NSAFlashMLAMetadata"):
        """Copy ``other``'s tensors into this instance's tensors, in place."""
        for dst, src in (
            (self.flashmla_metadata, other.flashmla_metadata),
            (self.num_splits, other.num_splits),
        ):
            dst.copy_(src)
67
+
68
+
69
@dataclass(frozen=True)
class NSAMetadata:
    """Per-forward-pass attention metadata for the NSA backend.

    Instances are built by ``NativeSparseAttnBackend.init_forward_metadata``
    and by its CUDA-graph capture path. CUDA-graph replay refreshes their
    tensors in place.
    """

    page_size: int

    # KV length of each sequence in the forward batch (int32).
    cache_seqlens_int32: torch.Tensor
    # Longest query length in the batch.
    max_seq_len_q: int
    # Longest key length in the batch.
    max_seq_len_k: int
    # Cumulative query lengths, with a leading 0.
    cu_seqlens_q: torch.Tensor
    # Cumulative key lengths, with a leading 0.
    cu_seqlens_k: torch.Tensor
    # KV-cache index table expressed with page_size == 1 (one entry per token).
    page_table_1: torch.Tensor

    # NOTE(dark): the table at the real page size is needed by
    #   1. dense decode/prefill: paged flash attention reads the real table;
    #   2. sparse decode/prefill: the indexer scores against the real table.
    real_page_table: torch.Tensor

    # NSA metadata. For prefill these are expanded to one entry per query token.
    nsa_cache_seqlens_int32: torch.Tensor  # seqlens clipped to `topk`
    nsa_cu_seqlens_q: torch.Tensor  # always arange(0, len(nsa_cu_seqlens_k))
    nsa_cu_seqlens_k: torch.Tensor  # cumsum of `nsa_cache_seqlens_int32`
    nsa_extend_seq_lens_list: List[int]
    nsa_seqlens_expanded: torch.Tensor  # expanded, unclipped `seqlens`
    # Each expanded NSA row holds a single query token, in decode and extend
    # alike, so this is always 1 (pinned by the Literal[1] type).
    nsa_max_seqlen_q: Literal[1] = 1

    # Only populated when the FlashMLA KV decode implementation is selected.
    flashmla_metadata: Optional["NSAFlashMLAMetadata"] = None
101
+
102
+
103
@dataclass(frozen=True)
class NSAIndexerMetadata(BaseIndexerMetadata):
    """Expose an ``NSAMetadata`` through the indexer-metadata interface."""

    attn_metadata: NSAMetadata

    def get_seqlens_int32(self) -> torch.Tensor:
        """Return the per-sequence KV lengths (``attn_metadata.cache_seqlens_int32``)."""
        meta = self.attn_metadata
        return meta.cache_seqlens_int32

    def get_page_table_64(self) -> torch.Tensor:
        """Return ``attn_metadata.real_page_table``.

        NOTE(review): the name suggests a page size of 64, but this returns the
        table for whatever real page size the backend was configured with.
        Confirm callers rely on that.
        """
        meta = self.attn_metadata
        return meta.real_page_table

    def get_seqlens_expanded(self) -> torch.Tensor:
        """Return the expanded, unclipped per-query-token KV lengths."""
        meta = self.attn_metadata
        return meta.nsa_seqlens_expanded

    def topk_transform(
        self,
        logits: torch.Tensor,
        topk: int,
    ) -> torch.Tensor:
        """Select the top-``topk`` KV positions per query row from ``logits``.

        With ``NSA_FUSE_TOPK`` the fused kernel also maps the selection through
        the page_size=1 table and returns a transformed page table directly.
        Otherwise the raw top-k result of ``fast_topk_v2`` is returned.
        """
        from sgl_kernel import fast_topk_transform_fused, fast_topk_v2

        lengths = self.get_seqlens_expanded()
        if NSA_FUSE_TOPK:
            meta = self.attn_metadata
            return fast_topk_transform_fused(
                score=logits,
                lengths=lengths,
                page_table_size_1=meta.page_table_1,
                cu_seqlens_q=meta.cu_seqlens_q,
                topk=topk,
            )
        return fast_topk_v2(logits, lengths, topk)
134
+
135
+
136
def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
    """Return cumulative lengths with a leading zero: ``[0, s0, s0+s1, ...]``.

    The result is int32 and has one more element than ``seqlens`` (for the
    expected 1-D input).

    Args:
        seqlens: int32 tensor of lengths, which must live on a CUDA device.

    Raises:
        AssertionError: if ``seqlens`` is not int32 or not on a CUDA device.
    """
    assert seqlens.dtype == torch.int32 and seqlens.is_cuda
    running_total = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
    # Pad the last dim with one element on the left (the zero offset).
    return torch.nn.functional.pad(running_total, (1, 0))
141
+
142
+
143
+ _NSA_IMPL_T: TypeAlias = Literal["flashmla_sparse", "flashmla_kv", "fa3", "tilelang"]
144
+
145
+ NSA_PREFILL_IMPL: _NSA_IMPL_T
146
+ NSA_DECODE_IMPL: _NSA_IMPL_T
147
+
148
+
149
+ class NativeSparseAttnBackend(AttentionBackend):
150
+ def __init__(
151
+ self,
152
+ model_runner: ModelRunner,
153
+ skip_prefill: bool = False,
154
+ speculative_step_id=0,
155
+ topk=0,
156
+ speculative_num_steps=0,
157
+ ):
158
+ super().__init__()
159
+ self.forward_metadata: NSAMetadata
160
+ self.device = model_runner.device
161
+ assert isinstance(model_runner.page_size, int)
162
+ self.real_page_size = model_runner.page_size
163
+ self.num_splits = (
164
+ 1 if model_runner.server_args.enable_deterministic_inference else 0
165
+ )
166
+ self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
167
+ assert self.use_nsa, "NSA backend only supports DeepSeek NSA"
168
+ self.nsa_kv_cache_store_fp8 = (
169
+ model_runner.token_to_kv_pool.nsa_kv_cache_store_fp8
170
+ )
171
+ self.nsa_index_topk = get_nsa_index_topk(model_runner.model_config.hf_config)
172
+ self.max_context_len = model_runner.model_config.context_len
173
+ self.num_q_heads = (
174
+ model_runner.model_config.num_attention_heads // get_attention_tp_size()
175
+ )
176
+ self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
177
+
178
+ assert model_runner.req_to_token_pool is not None
179
+ self.req_to_token = model_runner.req_to_token_pool.req_to_token
180
+
181
+ global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
182
+ NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill_backend
183
+ NSA_DECODE_IMPL = model_runner.server_args.nsa_decode_backend
184
+
185
+ self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
186
+
187
+ if _is_hip:
188
+ max_bs = model_runner.req_to_token_pool.size
189
+
190
+ self.kv_indptr = torch.zeros(
191
+ (max_bs + 1,), dtype=torch.int32, device=model_runner.device
192
+ )
193
+
194
+ # Speculative decoding
195
+ self.topk = model_runner.server_args.speculative_eagle_topk or 0
196
+ self.speculative_num_steps = speculative_num_steps
197
+ self.speculative_num_draft_tokens = (
198
+ model_runner.server_args.speculative_num_draft_tokens
199
+ )
200
+ self.speculative_step_id = speculative_step_id
201
+
202
+ def get_device_int32_arange(self, l: int) -> torch.Tensor:
203
+ if l > len(self._arange_buf):
204
+ next_pow_of_2 = 1 << (l - 1).bit_length()
205
+ self._arange_buf = torch.arange(
206
+ next_pow_of_2, device=self.device, dtype=torch.int32
207
+ )
208
+ return self._arange_buf[:l]
209
+
210
+ def _transform_table_1_to_real(self, page_table: torch.Tensor) -> torch.Tensor:
211
+ page_size = self.real_page_size
212
+ if page_size == 1:
213
+ return page_table
214
+ max_seqlen_k = page_table.shape[1]
215
+ strided_indices = torch.arange(
216
+ 0, max_seqlen_k, page_size, device=page_table.device, dtype=torch.int32
217
+ )
218
+ return page_table[:, strided_indices] // page_size
219
+
220
+ def init_forward_metadata(self, forward_batch: ForwardBatch):
221
+ """Init the metadata for a forward pass."""
222
+ batch_size = forward_batch.batch_size
223
+ device = forward_batch.seq_lens.device
224
+
225
+ if forward_batch.forward_mode.is_target_verify():
226
+ draft_token_num = self.speculative_num_draft_tokens
227
+ else:
228
+ draft_token_num = 0
229
+
230
+ cache_seqlens_int32 = (forward_batch.seq_lens + draft_token_num).to(torch.int32)
231
+ cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
232
+ assert forward_batch.seq_lens_cpu is not None
233
+ max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item() + draft_token_num)
234
+ page_table = forward_batch.req_to_token_pool.req_to_token[
235
+ forward_batch.req_pool_indices, :max_seqlen_k
236
+ ]
237
+
238
+ if forward_batch.forward_mode.is_decode_or_idle():
239
+ extend_seq_lens_cpu = [1] * batch_size
240
+ max_seqlen_q = 1
241
+ cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
242
+ seqlens_expanded = cache_seqlens_int32
243
+ elif forward_batch.forward_mode.is_target_verify():
244
+ max_seqlen_q = self.speculative_num_draft_tokens
245
+ nsa_max_seqlen_q = self.speculative_num_draft_tokens
246
+ cu_seqlens_q = torch.arange(
247
+ 0,
248
+ batch_size * self.speculative_num_draft_tokens + 1,
249
+ 1,
250
+ dtype=torch.int32,
251
+ device=device,
252
+ )
253
+ extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * batch_size
254
+ forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu
255
+
256
+ seqlens_int32_cpu = [
257
+ self.speculative_num_draft_tokens + kv_len
258
+ for kv_len in forward_batch.seq_lens_cpu.tolist()
259
+ ]
260
+ seqlens_expanded = torch.cat(
261
+ [
262
+ torch.arange(
263
+ kv_len - qo_len + 1,
264
+ kv_len + 1,
265
+ dtype=torch.int32,
266
+ device=device,
267
+ )
268
+ for qo_len, kv_len in zip(
269
+ extend_seq_lens_cpu,
270
+ seqlens_int32_cpu,
271
+ strict=True,
272
+ )
273
+ ]
274
+ )
275
+ page_table = torch.repeat_interleave(
276
+ page_table, repeats=self.speculative_num_draft_tokens, dim=0
277
+ )
278
+ elif forward_batch.forward_mode.is_extend():
279
+ assert (
280
+ forward_batch.extend_seq_lens_cpu is not None
281
+ and forward_batch.extend_seq_lens is not None
282
+ and forward_batch.extend_prefix_lens_cpu is not None
283
+ ), "All of them must not be None"
284
+ extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
285
+ assert forward_batch.extend_seq_lens is not None
286
+
287
+ if (
288
+ any(forward_batch.extend_prefix_lens_cpu)
289
+ or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
290
+ ):
291
+ max_seqlen_q = max(extend_seq_lens_cpu)
292
+ cu_seqlens_q = compute_cu_seqlens(
293
+ forward_batch.extend_seq_lens.to(torch.int32)
294
+ )
295
+ else:
296
+ max_seqlen_q = max_seqlen_k
297
+ cu_seqlens_q = cu_seqlens_k
298
+ seqlens_expanded = torch.cat(
299
+ [
300
+ torch.arange(
301
+ kv_len - qo_len + 1,
302
+ kv_len + 1,
303
+ dtype=torch.int32,
304
+ device=device,
305
+ )
306
+ for qo_len, kv_len in zip(
307
+ forward_batch.extend_seq_lens_cpu,
308
+ forward_batch.seq_lens_cpu.tolist(),
309
+ strict=True,
310
+ )
311
+ ]
312
+ )
313
+ else:
314
+ assert False, f"Unsupported {forward_batch.forward_mode = }"
315
+
316
+ # 1D, expanded seqlens (1D means cheap to compute, so always compute it)
317
+ nsa_cache_seqlens_int32 = compute_nsa_seqlens(
318
+ original_seq_lens=seqlens_expanded,
319
+ nsa_index_topk=self.nsa_index_topk,
320
+ )
321
+ nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
322
+ nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
323
+
324
+ metadata = NSAMetadata(
325
+ page_size=self.real_page_size,
326
+ cache_seqlens_int32=cache_seqlens_int32,
327
+ max_seq_len_q=max_seqlen_q,
328
+ max_seq_len_k=max_seqlen_k,
329
+ cu_seqlens_q=cu_seqlens_q,
330
+ cu_seqlens_k=cu_seqlens_k,
331
+ page_table_1=page_table,
332
+ flashmla_metadata=(
333
+ self._compute_flashmla_metadata(
334
+ cache_seqlens=nsa_cache_seqlens_int32,
335
+ seq_len_q=1,
336
+ )
337
+ if NSA_DECODE_IMPL == "flashmla_kv"
338
+ else None
339
+ ),
340
+ nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
341
+ nsa_cu_seqlens_q=nsa_cu_seqlens_q,
342
+ nsa_cu_seqlens_k=nsa_cu_seqlens_k,
343
+ nsa_seqlens_expanded=seqlens_expanded,
344
+ nsa_extend_seq_lens_list=extend_seq_lens_cpu,
345
+ real_page_table=self._transform_table_1_to_real(page_table),
346
+ nsa_max_seqlen_q=1,
347
+ )
348
+
349
+ self.forward_metadata = metadata
350
+
351
+ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
352
+ """Initialize CUDA graph state for the attention backend.
353
+
354
+ Args:
355
+ max_bs (int): Maximum batch size to support in CUDA graphs
356
+
357
+ This creates fixed-size tensors that will be reused during CUDA graph replay
358
+ to avoid memory allocations.
359
+ """
360
+ self.decode_cuda_graph_metadata: Dict = {
361
+ "cache_seqlens": torch.ones(
362
+ max_num_tokens, dtype=torch.int32, device=self.device
363
+ ),
364
+ "cu_seqlens_q": torch.arange(
365
+ 0, max_bs + 1, dtype=torch.int32, device=self.device
366
+ ),
367
+ "cu_seqlens_k": torch.zeros(
368
+ max_bs + 1, dtype=torch.int32, device=self.device
369
+ ),
370
+ # fake page_table for sparse_prefill
371
+ "page_table": torch.zeros(
372
+ max_num_tokens,
373
+ self.max_context_len,
374
+ dtype=torch.int32,
375
+ device=self.device,
376
+ ),
377
+ "flashmla_metadata": (
378
+ self._compute_flashmla_metadata(
379
+ cache_seqlens=torch.ones(
380
+ max_num_tokens, dtype=torch.int32, device=self.device
381
+ ),
382
+ seq_len_q=1,
383
+ )
384
+ if NSA_DECODE_IMPL == "flashmla_kv"
385
+ else None
386
+ ),
387
+ }
388
+
389
+ def init_forward_metadata_capture_cuda_graph(
390
+ self,
391
+ bs: int,
392
+ num_tokens: int,
393
+ req_pool_indices: torch.Tensor,
394
+ seq_lens: torch.Tensor,
395
+ encoder_lens: Optional[torch.Tensor],
396
+ forward_mode: ForwardMode,
397
+ spec_info: Optional[SpecInput],
398
+ ):
399
+ """Initialize forward metadata for capturing CUDA graph."""
400
+ if forward_mode.is_decode_or_idle():
401
+ # Normal Decode
402
+ # Get sequence information
403
+ cache_seqlens_int32 = seq_lens.to(torch.int32)
404
+ cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
405
+
406
+ # Use max context length for seq_len_k
407
+ page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
408
+ max_seqlen_q = 1
409
+ max_seqlen_k = page_table_1.shape[1]
410
+
411
+ # Precompute page table
412
+ # Precompute cumulative sequence lengths
413
+
414
+ # NOTE(dark): this is always arange, since we are decoding
415
+ cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
416
+ nsa_cache_seqlens_int32 = compute_nsa_seqlens(
417
+ cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
418
+ )
419
+
420
+ seqlens_expanded = cache_seqlens_int32
421
+ nsa_extend_seq_lens_list = [1] * num_tokens
422
+ if NSA_DECODE_IMPL == "flashmla_kv":
423
+ flashmla_metadata = self.decode_cuda_graph_metadata[
424
+ "flashmla_metadata"
425
+ ].slice(slice(0, num_tokens + 1))
426
+ flashmla_metadata.copy_(
427
+ self._compute_flashmla_metadata(
428
+ cache_seqlens=nsa_cache_seqlens_int32,
429
+ seq_len_q=1,
430
+ )
431
+ )
432
+ else:
433
+ flashmla_metadata = None
434
+ elif forward_mode.is_target_verify():
435
+ cache_seqlens_int32 = (seq_lens + self.speculative_num_draft_tokens).to(
436
+ torch.int32
437
+ )
438
+ cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
439
+ max_seqlen_q = 1
440
+ page_table_1 = self.decode_cuda_graph_metadata["page_table"][
441
+ : bs * self.speculative_num_draft_tokens, :
442
+ ]
443
+ max_seqlen_k = page_table_1.shape[1]
444
+
445
+ cu_seqlens_q = torch.arange(
446
+ 0,
447
+ bs * self.speculative_num_draft_tokens + 1,
448
+ 1,
449
+ dtype=torch.int32,
450
+ device=self.device,
451
+ )
452
+
453
+ extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
454
+
455
+ seqlens_int32_cpu = [
456
+ self.speculative_num_draft_tokens + kv_len
457
+ for kv_len in seq_lens.tolist()
458
+ ]
459
+ seqlens_expanded = torch.cat(
460
+ [
461
+ torch.arange(
462
+ kv_len - qo_len + 1,
463
+ kv_len + 1,
464
+ dtype=torch.int32,
465
+ device=self.device,
466
+ )
467
+ for qo_len, kv_len in zip(
468
+ extend_seq_lens_cpu,
469
+ seqlens_int32_cpu,
470
+ strict=True,
471
+ )
472
+ ]
473
+ )
474
+ nsa_cache_seqlens_int32 = compute_nsa_seqlens(
475
+ seqlens_expanded, nsa_index_topk=self.nsa_index_topk
476
+ )
477
+ nsa_extend_seq_lens_list = [1] * bs * self.speculative_num_draft_tokens
478
+
479
+ if NSA_DECODE_IMPL == "flashmla_kv":
480
+ flashmla_metadata = self.decode_cuda_graph_metadata[
481
+ "flashmla_metadata"
482
+ ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
483
+
484
+ flashmla_metadata.copy_(
485
+ self._compute_flashmla_metadata(
486
+ cache_seqlens=nsa_cache_seqlens_int32,
487
+ seq_len_q=1,
488
+ )
489
+ )
490
+ else:
491
+ flashmla_metadata = None
492
+ elif forward_mode.is_draft_extend():
493
+ cache_seqlens_int32 = (seq_lens + self.speculative_num_draft_tokens).to(
494
+ torch.int32
495
+ )
496
+ cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
497
+ page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
498
+ max_seqlen_k = page_table_1.shape[1]
499
+
500
+ extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
501
+ extend_seq_lens = torch.full(
502
+ (bs,),
503
+ self.speculative_num_draft_tokens,
504
+ device=self.device,
505
+ dtype=torch.int32,
506
+ )
507
+
508
+ max_seqlen_q = max(extend_seq_lens_cpu)
509
+ cu_seqlens_q = compute_cu_seqlens(extend_seq_lens.to(torch.int32))
510
+
511
+ seqlens_int32_cpu = [
512
+ self.speculative_num_draft_tokens + kv_len
513
+ for kv_len in seq_lens.tolist()
514
+ ]
515
+ seqlens_expanded = torch.cat(
516
+ [
517
+ torch.arange(
518
+ kv_len - qo_len + 1,
519
+ kv_len + 1,
520
+ dtype=torch.int32,
521
+ device=self.device,
522
+ )
523
+ for qo_len, kv_len in zip(
524
+ extend_seq_lens_cpu,
525
+ seqlens_int32_cpu,
526
+ strict=True,
527
+ )
528
+ ]
529
+ )
530
+ nsa_cache_seqlens_int32 = compute_nsa_seqlens(
531
+ seqlens_expanded, nsa_index_topk=self.nsa_index_topk
532
+ )
533
+ nsa_extend_seq_lens_list = [1] * bs
534
+
535
+ if NSA_DECODE_IMPL == "flashmla_kv":
536
+ flashmla_metadata = self.decode_cuda_graph_metadata[
537
+ "flashmla_metadata"
538
+ ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
539
+ # As the DeepGemm is not support for q_len = 3/4 in Indexer and every token has independent topk_indices,
540
+ # we made the Q shape [bs * speculative_num_draft_tokens, 1, head_nums, dim].
541
+ # So seq_len_q is 1 for flashmla_metadata in target_verify and draft_extend mode.
542
+ flashmla_metadata.copy_(
543
+ self._compute_flashmla_metadata(
544
+ cache_seqlens=nsa_cache_seqlens_int32,
545
+ seq_len_q=1,
546
+ )
547
+ )
548
+ else:
549
+ flashmla_metadata = None
550
+
551
+ nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
552
+ nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
553
+ real_page_table = self._transform_table_1_to_real(page_table_1)
554
+
555
+ metadata = NSAMetadata(
556
+ page_size=self.real_page_size,
557
+ cache_seqlens_int32=cache_seqlens_int32,
558
+ max_seq_len_q=max_seqlen_q,
559
+ max_seq_len_k=max_seqlen_k,
560
+ cu_seqlens_q=cu_seqlens_q,
561
+ cu_seqlens_k=cu_seqlens_k,
562
+ page_table_1=page_table_1,
563
+ flashmla_metadata=flashmla_metadata,
564
+ nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
565
+ nsa_cu_seqlens_q=nsa_cu_seqlens_q,
566
+ nsa_cu_seqlens_k=nsa_cu_seqlens_k,
567
+ nsa_seqlens_expanded=seqlens_expanded,
568
+ real_page_table=real_page_table,
569
+ nsa_extend_seq_lens_list=nsa_extend_seq_lens_list,
570
+ )
571
+ self.decode_cuda_graph_metadata[bs] = metadata
572
+ self.forward_metadata = metadata
573
+
574
+ def init_forward_metadata_replay_cuda_graph(
575
+ self,
576
+ bs: int,
577
+ req_pool_indices: torch.Tensor,
578
+ seq_lens: torch.Tensor,
579
+ seq_lens_sum: int,
580
+ encoder_lens: Optional[torch.Tensor],
581
+ forward_mode: ForwardMode,
582
+ spec_info: Optional[SpecInput],
583
+ seq_lens_cpu: Optional[torch.Tensor],
584
+ out_cache_loc: Optional[torch.Tensor] = None,
585
+ ):
586
+ """Initialize forward metadata for replaying CUDA graph."""
587
+ assert seq_lens_cpu is not None
588
+
589
+ seq_lens = seq_lens[:bs]
590
+ seq_lens_cpu = seq_lens_cpu[:bs]
591
+ req_pool_indices = req_pool_indices[:bs]
592
+
593
+ # Normal Decode
594
+ metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
595
+ if forward_mode.is_decode_or_idle():
596
+ # Normal Decode
597
+ max_len = int(seq_lens_cpu.max().item())
598
+
599
+ cache_seqlens = seq_lens.to(torch.int32)
600
+ metadata.cache_seqlens_int32.copy_(cache_seqlens)
601
+ metadata.cu_seqlens_k[1:].copy_(
602
+ torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
603
+ )
604
+ page_indices = self.req_to_token[req_pool_indices, :max_len]
605
+ metadata.page_table_1[:, :max_len].copy_(page_indices)
606
+ nsa_cache_seqlens = compute_nsa_seqlens(
607
+ cache_seqlens, nsa_index_topk=self.nsa_index_topk
608
+ )
609
+ metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
610
+ seqlens_expanded = cache_seqlens
611
+ elif forward_mode.is_target_verify():
612
+ max_seqlen_k = int(
613
+ seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
614
+ )
615
+
616
+ cache_seqlens = (seq_lens + self.speculative_num_draft_tokens).to(
617
+ torch.int32
618
+ )
619
+ metadata.cache_seqlens_int32.copy_(cache_seqlens)
620
+ metadata.cu_seqlens_k[1:].copy_(
621
+ torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
622
+ )
623
+ page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]
624
+ page_indices = torch.repeat_interleave(
625
+ page_indices, repeats=self.speculative_num_draft_tokens, dim=0
626
+ )
627
+ metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices)
628
+ extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
629
+
630
+ seqlens_int32_cpu = [
631
+ self.speculative_num_draft_tokens + kv_len
632
+ for kv_len in seq_lens_cpu.tolist()
633
+ ]
634
+ seqlens_expanded = torch.cat(
635
+ [
636
+ torch.arange(
637
+ kv_len - qo_len + 1,
638
+ kv_len + 1,
639
+ dtype=torch.int32,
640
+ device=self.device,
641
+ )
642
+ for qo_len, kv_len in zip(
643
+ extend_seq_lens_cpu,
644
+ seqlens_int32_cpu,
645
+ strict=True,
646
+ )
647
+ ]
648
+ )
649
+ metadata.nsa_seqlens_expanded.copy_(seqlens_expanded)
650
+ nsa_cache_seqlens = compute_nsa_seqlens(
651
+ seqlens_expanded, self.nsa_index_topk
652
+ )
653
+ metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
654
+ elif forward_mode.is_draft_extend():
655
+ max_seqlen_k = int(seq_lens_cpu.max().item())
656
+ cache_seqlens = seq_lens.to(torch.int32)
657
+ metadata.cache_seqlens_int32.copy_(cache_seqlens)
658
+ metadata.cu_seqlens_k[1:].copy_(
659
+ torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
660
+ )
661
+ page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]
662
+ metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices)
663
+ extend_seq_lens_cpu = spec_info.accept_length[:bs].tolist()
664
+
665
+ seqlens_int32_cpu = [
666
+ self.speculative_num_draft_tokens + kv_len
667
+ for kv_len in seq_lens_cpu.tolist()
668
+ ]
669
+ seqlens_expanded = torch.cat(
670
+ [
671
+ torch.arange(
672
+ kv_len - qo_len + 1,
673
+ kv_len + 1,
674
+ dtype=torch.int32,
675
+ device=self.device,
676
+ )
677
+ for qo_len, kv_len in zip(
678
+ extend_seq_lens_cpu,
679
+ seqlens_int32_cpu,
680
+ strict=True,
681
+ )
682
+ ]
683
+ )
684
+ metadata.nsa_seqlens_expanded[: seqlens_expanded.size(0)].copy_(
685
+ seqlens_expanded
686
+ )
687
+ nsa_cache_seqlens = compute_nsa_seqlens(
688
+ seqlens_expanded, self.nsa_index_topk
689
+ )
690
+ metadata.nsa_cache_seqlens_int32[: seqlens_expanded.size(0)].copy_(
691
+ nsa_cache_seqlens
692
+ )
693
+ seqlens_expanded_size = seqlens_expanded.size(0)
694
+ assert (
695
+ metadata.nsa_cache_seqlens_int32 is not None
696
+ and metadata.nsa_cu_seqlens_k is not None
697
+ and self.nsa_index_topk is not None
698
+ )
699
+
700
+ metadata.nsa_cu_seqlens_k[1 : 1 + seqlens_expanded_size].copy_(
701
+ torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
702
+ )
703
+ # NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy
704
+
705
+ assert self.real_page_size == metadata.page_size
706
+ if self.real_page_size > 1:
707
+ real_table = self._transform_table_1_to_real(page_indices)
708
+ new_len = real_table.shape[1]
709
+ metadata.real_page_table[:, :new_len].copy_(real_table)
710
+ else:
711
+ assert metadata.real_page_table is metadata.page_table_1
712
+
713
+ if NSA_DECODE_IMPL == "flashmla_kv":
714
+ flashmla_metadata = metadata.flashmla_metadata.slice(
715
+ slice(0, seqlens_expanded_size + 1)
716
+ )
717
+ flashmla_metadata.copy_(
718
+ self._compute_flashmla_metadata(
719
+ cache_seqlens=nsa_cache_seqlens,
720
+ seq_len_q=1,
721
+ )
722
+ )
723
+
724
+ self.forward_metadata = metadata
725
+
726
+ def forward_extend(
727
+ self,
728
+ q: torch.Tensor,
729
+ k: torch.Tensor,
730
+ v: torch.Tensor,
731
+ layer: RadixAttention,
732
+ forward_batch: ForwardBatch,
733
+ save_kv_cache=True,
734
+ # For multi-head latent attention
735
+ q_rope: Optional[torch.Tensor] = None,
736
+ k_rope: Optional[torch.Tensor] = None,
737
+ topk_indices: Optional[torch.Tensor] = None,
738
+ ) -> torch.Tensor:
739
+
740
+ if k is not None:
741
+ assert v is not None
742
+ if save_kv_cache:
743
+ cache_loc = (
744
+ forward_batch.out_cache_loc
745
+ if not layer.is_cross_attention
746
+ else forward_batch.encoder_out_cache_loc
747
+ )
748
+ forward_batch.token_to_kv_pool.set_mla_kv_buffer( # type: ignore
749
+ layer,
750
+ cache_loc,
751
+ k,
752
+ k_rope,
753
+ )
754
+
755
+ metadata = self.forward_metadata
756
+ causal = not layer.is_cross_attention
757
+ assert causal, "NSA is causal only"
758
+
759
+ # For fa3 interface version compatibility, we put new fields into conditional keyword args
760
+ kwargs = {}
761
+
762
+ # Do absorbed multi-latent attention
763
+ assert q_rope is not None
764
+ kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
765
+
766
+ # when store in fp8 and compute in fp8, no need to convert dtype
767
+ if not (
768
+ NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and self.nsa_kv_cache_store_fp8
769
+ ):
770
+ kv_cache = kv_cache.to(q.dtype)
771
+
772
+ if q_rope is not None:
773
+ q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
774
+ q_rope = q_rope.view(
775
+ -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
776
+ )
777
+ else:
778
+ q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
779
+ q_nope = q_all[:, :, : layer.v_head_dim]
780
+ q_rope = q_all[:, :, layer.v_head_dim :]
781
+
782
+ # NOTE(dark): here, we use page size = 1
783
+
784
+ if NSA_FUSE_TOPK:
785
+ page_table_1 = topk_indices
786
+ else:
787
+ assert metadata.nsa_extend_seq_lens_list is not None
788
+ page_table_1 = transform_index_page_table_prefill(
789
+ page_table=metadata.page_table_1,
790
+ topk_indices=topk_indices,
791
+ extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
792
+ page_size=1,
793
+ )
794
+ if NSA_PREFILL_IMPL == "tilelang":
795
+ if q_rope is not None:
796
+ q_all = torch.cat([q_nope, q_rope], dim=-1)
797
+ return self._forward_tilelang(
798
+ q_all=q_all,
799
+ kv_cache=kv_cache,
800
+ page_table_1=page_table_1,
801
+ sm_scale=layer.scaling,
802
+ v_head_dim=layer.v_head_dim,
803
+ )
804
+ elif NSA_PREFILL_IMPL == "flashmla_sparse":
805
+ if q_rope is not None:
806
+ q_all = torch.cat([q_nope, q_rope], dim=-1)
807
+ return self._forward_flashmla_sparse(
808
+ q_all=q_all,
809
+ kv_cache=kv_cache,
810
+ page_table_1=page_table_1,
811
+ sm_scale=layer.scaling,
812
+ v_head_dim=layer.v_head_dim,
813
+ )
814
+ elif NSA_PREFILL_IMPL == "flashmla_kv":
815
+ if q_rope is not None:
816
+ q_all = torch.cat([q_nope, q_rope], dim=-1)
817
+ return self._forward_flashmla_kv(
818
+ q_all=q_all,
819
+ kv_cache=kv_cache,
820
+ sm_scale=layer.scaling,
821
+ v_head_dim=layer.v_head_dim,
822
+ # TODO optimize args
823
+ layer=layer,
824
+ metadata=metadata,
825
+ page_table_1=page_table_1,
826
+ )
827
+ elif NSA_PREFILL_IMPL == "fa3":
828
+ return self._forward_fa3(
829
+ q_rope=q_rope,
830
+ kv_cache=kv_cache,
831
+ v_head_dim=layer.v_head_dim,
832
+ q_nope=q_nope,
833
+ page_table=page_table_1,
834
+ cache_seqlens=metadata.nsa_cache_seqlens_int32,
835
+ cu_seqlens_q=metadata.nsa_cu_seqlens_q,
836
+ cu_seqlens_k=metadata.nsa_cu_seqlens_k,
837
+ max_seqlen_q=metadata.nsa_max_seqlen_q,
838
+ sm_scale=layer.scaling,
839
+ logit_cap=layer.logit_cap,
840
+ page_size=1,
841
+ )
842
+ else:
843
+ raise ValueError(f"Unsupported {NSA_PREFILL_IMPL = }")
844
+
845
+ def forward_decode(
846
+ self,
847
+ q: torch.Tensor,
848
+ k: torch.Tensor,
849
+ v: torch.Tensor,
850
+ layer: RadixAttention,
851
+ forward_batch: ForwardBatch,
852
+ save_kv_cache=True,
853
+ # For multi-head latent attention
854
+ q_rope: Optional[torch.Tensor] = None,
855
+ k_rope: Optional[torch.Tensor] = None,
856
+ topk_indices: Optional[torch.Tensor] = None,
857
+ ) -> torch.Tensor:
858
+ if k is not None:
859
+ assert v is not None
860
+ if save_kv_cache:
861
+ cache_loc = (
862
+ forward_batch.out_cache_loc
863
+ if not layer.is_cross_attention
864
+ else forward_batch.encoder_out_cache_loc
865
+ )
866
+ forward_batch.token_to_kv_pool.set_mla_kv_buffer( # type: ignore
867
+ layer,
868
+ cache_loc,
869
+ k,
870
+ k_rope,
871
+ )
872
+
873
+ metadata = self.forward_metadata
874
+ causal = not layer.is_cross_attention
875
+ assert causal, "NSA is causal only"
876
+
877
+ # Do absorbed multi-latent attention
878
+ kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
879
+ if q_rope is not None:
880
+ q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
881
+ q_rope = q_rope.view(
882
+ -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
883
+ )
884
+ else:
885
+ q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
886
+ q_nope = q_all[:, :, : layer.v_head_dim]
887
+ q_rope = q_all[:, :, layer.v_head_dim :]
888
+
889
+ if NSA_FUSE_TOPK:
890
+ page_table_1 = topk_indices
891
+ else:
892
+ page_table_1 = transform_index_page_table_decode(
893
+ page_table=metadata.page_table_1,
894
+ topk_indices=topk_indices,
895
+ page_size=1,
896
+ )
897
+
898
+ if NSA_DECODE_IMPL == "flashmla_sparse":
899
+ if q_rope is not None:
900
+ q_all = torch.cat([q_nope, q_rope], dim=-1)
901
+ return self._forward_flashmla_sparse(
902
+ q_all=q_all,
903
+ kv_cache=kv_cache,
904
+ page_table_1=page_table_1,
905
+ sm_scale=layer.scaling,
906
+ v_head_dim=layer.v_head_dim,
907
+ )
908
+ elif NSA_DECODE_IMPL == "flashmla_kv":
909
+ if q_rope is not None:
910
+ q_all = torch.cat([q_nope, q_rope], dim=-1)
911
+ return self._forward_flashmla_kv(
912
+ q_all=q_all,
913
+ kv_cache=kv_cache,
914
+ sm_scale=layer.scaling,
915
+ v_head_dim=layer.v_head_dim,
916
+ # TODO optimize args
917
+ layer=layer,
918
+ metadata=metadata,
919
+ page_table_1=page_table_1,
920
+ )
921
+ elif NSA_DECODE_IMPL == "tilelang":
922
+ if q_rope is not None:
923
+ q_all = torch.cat([q_nope, q_rope], dim=-1)
924
+ return self._forward_tilelang(
925
+ q_all=q_all,
926
+ kv_cache=kv_cache,
927
+ page_table_1=page_table_1,
928
+ sm_scale=layer.scaling,
929
+ v_head_dim=layer.v_head_dim,
930
+ )
931
+ elif NSA_DECODE_IMPL == "fa3":
932
+ return self._forward_fa3(
933
+ q_rope=q_rope,
934
+ kv_cache=kv_cache,
935
+ v_head_dim=layer.v_head_dim,
936
+ q_nope=q_nope,
937
+ page_table=page_table_1,
938
+ cache_seqlens=metadata.nsa_cache_seqlens_int32,
939
+ cu_seqlens_q=metadata.nsa_cu_seqlens_q,
940
+ cu_seqlens_k=metadata.nsa_cu_seqlens_k,
941
+ max_seqlen_q=metadata.nsa_max_seqlen_q,
942
+ sm_scale=layer.scaling,
943
+ logit_cap=layer.logit_cap,
944
+ page_size=1,
945
+ )
946
+ elif NSA_DECODE_IMPL == "aiter":
947
+ if q_rope is not None:
948
+ q_all = torch.cat([q_nope, q_rope], dim=-1)
949
+ return self._forward_aiter(
950
+ q_all=q_all,
951
+ kv_cache=kv_cache,
952
+ page_table_1=page_table_1,
953
+ layer=layer,
954
+ metadata=metadata,
955
+ bs=forward_batch.batch_size,
956
+ )
957
+
958
+ else:
959
+ assert False, f"Unsupported {NSA_DECODE_IMPL = }"
960
+
961
def _forward_fa3(
    self,
    q_rope: torch.Tensor,
    kv_cache: torch.Tensor,
    v_head_dim: int,
    q_nope: torch.Tensor,
    page_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    sm_scale: float,
    logit_cap: float,
    page_size: int,
) -> torch.Tensor:
    """Run sparse MLA attention through ``flash_attn_with_kvcache``.

    ``kv_cache`` is split along its last dimension at ``v_head_dim``.
    The leading ``v_head_dim`` channels become the value cache; the
    remaining channels become the rope key cache.  The rope query is
    passed as ``q`` and the no-rope query as ``qv``.

    Args:
        q_rope: Rope part of the query.
        kv_cache: Packed latent + rope cache, sliced on dims 0-2.
        v_head_dim: Split point of the cache's last dimension.
        q_nope: No-rope part of the query (passed as ``qv``).
        page_table: Page table handed to the kernel.
        cache_seqlens: Per-sequence cached lengths.
        cu_seqlens_q: Cumulative query lengths.
        cu_seqlens_k: Passed as ``cu_seqlens_k_new``.
        max_seqlen_q: Maximum query length.
        sm_scale: Softmax scale.
        logit_cap: Passed as ``softcap``.
        page_size: Second dimension of the reshaped caches.

    Returns:
        The kernel output (softmax LSE is not requested).
    """
    latent_cache = kv_cache[:, :, :v_head_dim]
    rope_cache = kv_cache[:, :, v_head_dim:]
    rope_dim = rope_cache.shape[-1]

    # Reshape both halves to (-1, page_size, 1, dim) for the kernel.
    paged_rope = rope_cache.view(-1, page_size, 1, rope_dim)
    paged_latent = latent_cache.view(-1, page_size, 1, v_head_dim)

    output = flash_attn_with_kvcache(
        q=q_rope,
        k_cache=paged_rope,
        v_cache=paged_latent,
        qv=q_nope,
        page_table=page_table,
        cache_seqlens=cache_seqlens,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k_new=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        softmax_scale=sm_scale,
        causal=True,
        softcap=logit_cap,
        return_softmax_lse=False,
        num_splits=self.num_splits,
    )
    return output  # type: ignore
998
+
999
def _forward_flashmla_sparse(
    self,
    q_all: torch.Tensor,
    kv_cache: torch.Tensor,
    v_head_dim: int,
    page_table_1: torch.Tensor,
    sm_scale: float,
) -> torch.Tensor:
    """Run sparse attention with FlashMLA's ``flash_mla_sparse_fwd``.

    Args:
        q_all: Full query (no-rope and rope parts concatenated by the caller).
        kv_cache: KV cache passed straight through as ``kv``.
        v_head_dim: Value head dimension, passed as ``d_v``.
        page_table_1: Selected token indices; a singleton dim is inserted
            at position 1 before handing them to the kernel.
        sm_scale: Softmax scale.

    Returns:
        The first element of the kernel's 3-tuple result.
    """
    from flash_mla import flash_mla_sparse_fwd

    sparse_indices = page_table_1.unsqueeze(1)
    output, _aux0, _aux1 = flash_mla_sparse_fwd(
        q=q_all,
        kv=kv_cache,
        indices=sparse_indices,
        sm_scale=sm_scale,
        d_v=v_head_dim,
    )
    return output
1017
+
1018
def _forward_flashmla_kv(
    self,
    q_all: torch.Tensor,
    kv_cache: torch.Tensor,
    v_head_dim: int,
    sm_scale: float,
    layer,
    metadata: NSAMetadata,
    page_table_1,
) -> torch.Tensor:
    """Run sparse decode attention through FlashMLA's ``flash_mla_with_kvcache``.

    The query is viewed as ``(-1, 1, tp_q_head_num, head_dim)`` and the cache
    as ``(-1, real_page_size, 1, kv_cache_dim)``.  When FP8 decode compute is
    enabled but the cache is not stored in FP8, the whole cache is quantized
    on the fly.

    Args:
        q_all: Full query tensor.
        kv_cache: KV cache to be viewed in paged layout.
        v_head_dim: Value head dimension (``head_dim_v``).
        sm_scale: Softmax scale.
        layer: Attention layer; ``tp_q_head_num`` and ``head_dim`` are read.
        metadata: Provides ``nsa_cache_seqlens_int32`` and FlashMLA
            scheduler metadata / split counts.
        page_table_1: Selected token indices; the last dim must equal
            ``self.nsa_index_topk``.

    Returns:
        The attention output (first element of the kernel result).

    Raises:
        AssertionError: If the page size is not 64 or the number of indices
            does not match ``nsa_index_topk``.
    """
    from flash_mla import flash_mla_with_kvcache

    # Validate the page size before reshaping the cache with it, so a mismatch
    # reports this message rather than an opaque ``view`` shape error.
    assert self.real_page_size == 64, "only page size 64 is supported"

    cache_seqlens = metadata.nsa_cache_seqlens_int32

    # TODO the 2nd dim is seq_len_q, need to be >1 when MTP
    q_all = q_all.view(-1, 1, layer.tp_q_head_num, layer.head_dim)
    kv_cache = kv_cache.view(-1, self.real_page_size, 1, self.kv_cache_dim)

    if NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and not self.nsa_kv_cache_store_fp8:
        # FP8 compute requested but the cache is not stored in FP8:
        # inefficiently quantize the whole cache.
        kv_cache = quantize_k_cache(kv_cache)

    indices = page_table_1.unsqueeze(1)
    # Requirement of the FlashMLA decode kernel.
    assert indices.shape[-1] == self.nsa_index_topk, (
        f"FlashMLA decode expects {self.nsa_index_topk} indices per query, "
        f"got {indices.shape[-1]}"
    )

    o, _ = flash_mla_with_kvcache(
        q=q_all,
        k_cache=kv_cache,
        cache_seqlens=cache_seqlens,
        head_dim_v=v_head_dim,
        tile_scheduler_metadata=metadata.flashmla_metadata.flashmla_metadata,
        num_splits=metadata.flashmla_metadata.num_splits,
        softmax_scale=sm_scale,
        indices=indices,
        # doc says it is not used, but if pass in None then error
        block_table=torch.empty(
            (q_all.shape[0], 0), dtype=torch.int32, device=q_all.device
        ),
        is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
    )
    return o
1062
+
1063
def _forward_tilelang(
    self,
    q_all: torch.Tensor,
    kv_cache: torch.Tensor,
    v_head_dim: int,
    page_table_1: torch.Tensor,
    sm_scale: float,
) -> torch.Tensor:
    """Run sparse attention with the TileLang kernel ``tilelang_sparse_fwd``.

    Args:
        q_all: Full query tensor.
        kv_cache: KV cache passed through as ``kv``.
        v_head_dim: Value head dimension, passed as ``d_v``.
        page_table_1: Selected token indices; a singleton dim is inserted
            at position 1 before the call.
        sm_scale: Softmax scale.

    Returns:
        Whatever ``tilelang_sparse_fwd`` returns.
    """
    from sglang.srt.layers.attention.nsa.tilelang_kernel import tilelang_sparse_fwd

    sparse_indices = page_table_1.unsqueeze(1)
    output = tilelang_sparse_fwd(
        q=q_all,
        kv=kv_cache,
        indices=sparse_indices,
        sm_scale=sm_scale,
        d_v=v_head_dim,
    )
    return output
1080
+
1081
def _forward_aiter(
    self,
    q_all: torch.Tensor,
    kv_cache: torch.Tensor,
    page_table_1: torch.Tensor,
    layer: RadixAttention,
    metadata: NSAMetadata,
    bs: int,
) -> torch.Tensor:
    """Run sparse MLA decode through aiter's ``mla_decode_fwd``.

    Entries of ``page_table_1`` equal to -1 are treated as padding.  They are
    excluded from the per-row counts whose cumulative sum is written into
    ``self.kv_indptr[1 : bs + 1]``, and they are dropped from the flattened
    ``kv_indices``.

    Args:
        q_all: Query, reshaped to ``(-1, tp_q_head_num * head_dim)``.
        kv_cache: KV cache, viewed as ``(-1, 1, 1, head_dim)``.
        page_table_1: Selected token indices per row, -1 meaning padding.
        layer: Attention layer providing head counts/dims, scaling, logit cap.
        metadata: Provides ``cu_seqlens_q`` and ``max_seq_len_q``.
        bs: Number of rows of ``page_table_1`` covered by ``kv_indptr``.

    Returns:
        Output of shape ``(q.shape[0], tp_q_head_num * v_head_dim)`` filled
        in place by the kernel.
    """
    q = q_all.reshape(-1, layer.tp_q_head_num * layer.head_dim)

    if layer.head_dim != layer.v_head_dim:
        o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
    else:
        o = torch.empty_like(q)

    # Build the padding mask once; it drives both the indptr and the indices.
    valid_mask = page_table_1 != -1

    # self.kv_indptr is updated in place.
    # NOTE(review): kv_indptr[0] is not written here; assumes it is already 0.
    kv_indptr = self.kv_indptr
    kv_indptr[1 : bs + 1] = torch.cumsum(valid_mask.sum(dim=1), dim=0)

    kv_indices = page_table_1[valid_mask]

    mla_decode_fwd(
        q.view(-1, layer.tp_q_head_num, layer.head_dim),
        kv_cache.view(-1, 1, 1, layer.head_dim),
        o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
        metadata.cu_seqlens_q,
        kv_indptr,
        kv_indices,
        # NOTE(review): cu_seqlens_q is passed again in this position; confirm
        # this matches what aiter's mla_decode_fwd expects here (its signature
        # appears to name this slot kv_last_page_lens).
        metadata.cu_seqlens_q,
        metadata.max_seq_len_q,
        layer.scaling,
        layer.logit_cap,
    )
    return o
1119
+
1120
def get_cuda_graph_seq_len_fill_value(self):
    """Return the sequence-length fill value for CUDA graphs (always 1)."""
    fill_value = 1
    return fill_value
1123
+
1124
def get_indexer_metadata(
    self, layer_id: int, forward_batch: ForwardBatch
) -> NSAIndexerMetadata:
    """Wrap the backend's current ``forward_metadata`` for the NSA indexer.

    ``layer_id`` and ``forward_batch`` are not used by this implementation.
    """
    current_metadata = self.forward_metadata
    return NSAIndexerMetadata(attn_metadata=current_metadata)
1128
+
1129
def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
    """Build FlashMLA scheduler metadata for the given cache lengths.

    Calls ``flash_mla.get_mla_metadata`` with a single KV head, this backend's
    query head count and top-k, and the configured FP8 flag.

    Args:
        cache_seqlens: Per-sequence cached lengths.
        seq_len_q: Number of query tokens per sequence.

    Returns:
        ``NSAFlashMLAMetadata`` holding the scheduler metadata and split counts.
    """
    from flash_mla import get_mla_metadata

    num_heads_k = 1
    # TODO doc says `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`
    # but the name looks like need seq_len_q?
    q_tokens_per_head_k = seq_len_q * self.num_q_heads // num_heads_k

    scheduler_metadata, split_counts = get_mla_metadata(
        cache_seqlens=cache_seqlens,
        num_q_tokens_per_head_k=q_tokens_per_head_k,
        num_heads_k=num_heads_k,
        num_heads_q=self.num_q_heads,
        is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
        topk=self.nsa_index_topk,
    )

    return NSAFlashMLAMetadata(
        flashmla_metadata=scheduler_metadata,
        num_splits=split_counts,
    )
1147
+
1148
+
1149
+ class NativeSparseAttnMultiStepBackend:
1150
+
1151
def __init__(
    self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
):
    """Create one ``NativeSparseAttnBackend`` per speculative step.

    Args:
        model_runner: Runner shared by every per-step backend.
        topk: Stored and forwarded to each per-step backend.
        speculative_num_steps: Number of per-step backends to build; each
            receives its own ``speculative_step_id``.
    """
    self.model_runner = model_runner
    self.topk = topk
    self.speculative_num_steps = speculative_num_steps
    self.attn_backends = [
        NativeSparseAttnBackend(
            model_runner,
            speculative_step_id=step_id,
            topk=self.topk,
            speculative_num_steps=self.speculative_num_steps,
        )
        for step_id in range(self.speculative_num_steps)
    ]
1167
+
1168
def init_forward_metadata(self, forward_batch: ForwardBatch):
    """Initialize forward metadata on every per-step backend except the last."""
    for step in range(self.speculative_num_steps - 1):
        step_backend = self.attn_backends[step]
        step_backend.init_forward_metadata(forward_batch)
1171
+
1172
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
    """Forward CUDA-graph state initialization to each per-step backend.

    Args:
        max_bs: Maximum batch size, passed through unchanged.
        max_num_tokens: Maximum token count, passed through unchanged.
    """
    for step in range(self.speculative_num_steps):
        step_backend = self.attn_backends[step]
        step_backend.init_cuda_graph_state(max_bs, max_num_tokens)
1175
+
1176
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
    """Prepare every per-step backend for CUDA-graph capture in DECODE mode.

    Each backend receives the batch size, ``batch_size * self.topk`` as the
    token count, the batch's request-pool indices, sequence lengths and
    speculative info, with no encoder lengths.
    """
    for step in range(self.speculative_num_steps):
        batch_size = forward_batch.batch_size
        self.attn_backends[step].init_forward_metadata_capture_cuda_graph(
            batch_size,
            batch_size * self.topk,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            encoder_lens=None,
            forward_mode=ForwardMode.DECODE,
            spec_info=forward_batch.spec_info,
        )
1187
+
1188
+ def init_forward_metadata_replay_cuda_graph(
1189
+ self, forward_batch: ForwardBatch, bs: int
1190
+ ):
1191
+ for i in range(self.speculative_num_steps):
1192
+ self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
1193
+ bs,
1194
+ forward_batch.req_pool_indices,
1195
+ forward_batch.seq_lens,
1196
+ seq_lens_sum=-1,
1197
+ encoder_lens=None,
1198
+ forward_mode=ForwardMode.DECODE,
1199
+ spec_info=forward_batch.spec_info,
1200
+ seq_lens_cpu=forward_batch.seq_lens_cpu,
1201
+ )