sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (395)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/aiter_backend.py

@@ -4,27 +4,25 @@ from __future__ import annotations
 end to end attention solution with aiter kernels
 """
 
-import math
-import os
 from dataclasses import dataclass
 from enum import Enum, auto
-from functools import partial
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Optional
 
 import torch
 import triton
-import triton.language as tl
 
-from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
-from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput
 
 try:
     from aiter import (
@@ -154,6 +152,8 @@ class AiterAttnBackend(AttentionBackend):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
 
+        self.enable_dp_attention = is_dp_attention_enabled()
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
 
@@ -302,19 +302,19 @@ class AiterAttnBackend(AttentionBackend):
             if self.use_mla:
                 self.mla_indices_updater_prefill.update(
                     forward_batch.req_pool_indices,
-                    forward_batch.extend_prefix_lens,
-                    sum(forward_batch.extend_prefix_lens_cpu),
+                    forward_batch.seq_lens,
+                    forward_batch.seq_lens_sum,
                     forward_batch.extend_seq_lens,
-                    max(forward_batch.extend_seq_lens_cpu),
-                    forward_batch.seq_lens_cpu.max().item(),
+                    forward_batch.extend_seq_lens.max().item(),
+                    forward_batch.seq_lens.max().item(),
                     spec_info=None,
                 )
-                self.mla_indices_updater_prefill.kv_indptr += (
-                    self.mla_indices_updater_prefill.qo_indptr
-                )
+
+                kv_indices = self.mla_indices_updater_prefill.kv_indices
+
                 self.forward_metadata = ForwardMetadata(
                     self.mla_indices_updater_prefill.kv_indptr,
-                    self.mla_indices_updater_prefill.kv_indices,
+                    kv_indices,
                     self.mla_indices_updater_prefill.qo_indptr,
                     self.kv_last_page_len[:bs],
                     self.mla_indices_updater_prefill.max_q_len,
@@ -369,7 +369,7 @@ class AiterAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             qo_indptr = None
@@ -504,7 +504,7 @@ class AiterAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -614,66 +614,90 @@ class AiterAttnBackend(AttentionBackend):
         assert len(k.shape) == 3
         assert len(v.shape) == 3
 
-        if kv_indices.shape[0] == 0:
-            o = flash_attn_varlen_func(
-                q,
-                k,
-                v,
-                qo_indptr,
-                qo_indptr,
-                max_q_len,
-                max_q_len,
-                softmax_scale=layer.scaling,
-                causal=True,
-            )
-            return o
-        elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
-            K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
-            kvc, k_pe = torch.split(
-                K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
-            )
-            kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
+        if (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ):
+            if kv_indices.shape[0] == 0:
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    qo_indptr,
+                    qo_indptr,
+                    max_q_len,
+                    max_q_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                )
+                return o
+            elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
+                K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
+                kvc, k_pe = torch.split(
+                    K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
+                )
+                kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
 
-            kvprefix = kvprefix.view(
-                -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
-            )
-            k_prefix, v_prefix = torch.split(
-                kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
-            )
-            k_prefix = torch.cat(
-                [
-                    k_prefix,
-                    torch.broadcast_to(
-                        k_pe,
-                        (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
-                    ),
-                ],
-                dim=-1,
-            )
-            assert (
-                forward_batch.extend_prefix_lens.shape
-                == forward_batch.extend_seq_lens.shape
-            )
-            k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
-            k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
-            assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
-            k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
-            v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
-            v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
-            v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
-
-            o = flash_attn_varlen_func(
-                q,
-                k,
-                v,
-                qo_indptr,
-                kv_indptr,
-                max_q_len,
-                max_kv_len,
-                softmax_scale=layer.scaling,
-                causal=True,
-            )
-            return o
+                kvprefix = kvprefix.view(
+                    -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
+                )
+                k_prefix, v_prefix = torch.split(
+                    kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
+                )
+                k_prefix = torch.cat(
+                    [
+                        k_prefix,
+                        torch.broadcast_to(
+                            k_pe,
+                            (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
+                        ),
+                    ],
+                    dim=-1,
+                )
+                assert (
+                    forward_batch.extend_prefix_lens.shape
+                    == forward_batch.extend_seq_lens.shape
+                )
+
+                k = k_prefix
+                v = v_prefix
+
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    qo_indptr,
+                    kv_indptr,
+                    max_q_len,
+                    max_kv_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                )
+                return o
+
+            else:
+                if layer.qk_head_dim != layer.v_head_dim:
+                    o = q.new_empty(
+                        (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                    )
+                else:
+                    o = torch.empty_like(q)
+
+                mla_prefill_fwd(
+                    q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                    qo_indptr,
+                    kv_indptr,
+                    kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, layer.tp_k_head_num, layer.qk_head_dim)
+                return o
         elif forward_batch.forward_mode.is_target_verify():
             o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
             mla_decode_fwd(
@@ -859,7 +883,7 @@ class AiterIndicesUpdaterPrefill:
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -871,7 +895,7 @@ class AiterIndicesUpdaterPrefill:
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
 
         kv_start_idx = None
@@ -955,7 +979,7 @@ class AiterMlaIndicesUpdaterPrefill:
         extend_lens: torch.Tensor,
         max_q_len: int,
         max_kv_len: int,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -968,7 +992,7 @@ class AiterMlaIndicesUpdaterPrefill:
         extend_lens: torch.Tensor,
         max_q_len: int,
         max_kv_len: int,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         bs = len(req_pool_indices)
 
@@ -1025,7 +1049,7 @@ class AiterMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
 
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps

sglang/srt/layers/attention/ascend_backend.py

@@ -5,13 +5,15 @@ from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch_npu
-from torch.nn.functional import scaled_dot_product_attention
 
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess_enabled
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import get_bool_env_var
 
 if TYPE_CHECKING:
@@ -33,6 +35,9 @@ class ForwardMetadata:
     extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_list: Optional[List[int]] = None
+    seq_lens_list_cumsum: Optional[List[int]] = None
+    seq_lens: Optional[torch.Tensor] = None
+    actual_seq_lengths_q: Optional[torch.Tensor] = None
 
 
 class AscendAttnBackend(AttentionBackend):
@@ -64,6 +69,9 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+            self.q_head_dim = (
+                self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
+            )
         self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
@@ -83,6 +91,7 @@ class AscendAttnBackend(AttentionBackend):
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
+        tp_size = get_attention_tp_size()
         self.forward_metadata = ForwardMetadata()
 
         self.forward_metadata.block_tables = (
@@ -96,9 +105,9 @@ class AscendAttnBackend(AttentionBackend):
                 forward_batch.extend_seq_lens.cpu().int()
             )
             self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
-            self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
-                forward_batch.extend_seq_lens_cpu
-            )
+
+            seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
+            self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
 
         self.graph_mode = False
 
@@ -119,12 +128,16 @@ class AscendAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         metadata = ForwardMetadata()
 
         metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
         metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
+        metadata.seq_lens = seq_lens
+        metadata.actual_seq_lengths_q = torch.tensor(
+            [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
+        )
 
         self.graph_metadata[bs] = metadata
         self.forward_metadata = metadata
@@ -139,7 +152,7 @@ class AscendAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         metadata = self.graph_metadata[bs]
@@ -153,6 +166,8 @@ class AscendAttnBackend(AttentionBackend):
         metadata.block_tables[:bs, max_seq_pages:].fill_(0)
         metadata.block_tables[bs:, :].fill_(0)
 
+        metadata.seq_lens[:bs].copy_(seq_lens[:bs])
+
         self.forward_metadata = metadata
 
         self.graph_mode = True
@@ -160,6 +175,64 @@ class AscendAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 0
 
+    def forward_sparse(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: torch.Tensor = None,
+    ):
+
+        is_prefill = forward_batch.forward_mode.is_extend()
+
+        if save_kv_cache:
+            k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+            k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
+        q_nope, q_pe = q, q_rope
+        k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        block_table = self.forward_metadata.block_tables
+        if is_prefill:
+            actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
+        else:
+            if self.forward_metadata.actual_seq_lengths_q is None:
+                actual_seq_qlen = (
+                    torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
+                )
+            else:
+                actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
+        if self.forward_metadata.seq_lens_cpu_int is None:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens
+        else:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
+
+        attn_out = torch.ops.custom.npu_sparse_flash_attention(
+            query=q_nope,
+            key=k_nope,
+            value=k_nope,
+            query_rope=q_pe,
+            key_rope=k_pe,
+            sparse_indices=topk_indices,
+            scale_value=layer.scaling,
+            actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
+            actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
+            block_table=block_table,
+            sparse_block_size=1,
+            layout_query="TND",
+            layout_kv="PA_BSND",
+            sparse_mode=3,
+        )
+
+        return attn_out
+
     def forward_extend(
         self,
         q,
@@ -168,7 +241,23 @@ class AscendAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
         if not self.use_mla:
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -368,7 +457,7 @@ class AscendAttnBackend(AttentionBackend):
             -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
         )
 
-        q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank)
+        q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank).contiguous()
         q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
         if self.forward_metadata.seq_lens_cpu_int is None:
            actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
@@ -394,7 +483,7 @@ class AscendAttnBackend(AttentionBackend):
             antiquant_scale=None,
             sparse_mode=0,
         )
-        output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
+        output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
         softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
 
         torch_npu.npu_fused_infer_attention_score.out(
@@ -429,7 +518,24 @@ class AscendAttnBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if is_mla_preprocess_enabled():
+            # MLAPO does saving kv_cache
+            save_kv_cache = False
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
+
         if self.graph_mode:
             return self.forward_decode_graph(
                 q,