sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/activation.py

@@ -35,6 +35,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    is_xpu,
     set_weight_attrs,
 )
 from sglang.utils import resolve_obj_by_qualname
@@ -44,8 +45,9 @@ _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_hip = is_hip()
+_is_xpu = is_xpu()

-if _is_cuda:
+if _is_cuda or _is_xpu:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 elif _is_hip:
     from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
@@ -70,8 +72,6 @@ class SiluAndMul(CustomOp):

     def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
         if _is_cpu_amx_available:
-            d = x.shape[-1] // 2
-            output_shape = x.shape[:-1] + (d,)
             out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
             return out
         else:
@@ -81,17 +81,20 @@ class SiluAndMul(CustomOp):
             out = torch_npu.npu_swiglu(x)
             return out

+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        silu_and_mul(x, out)
+        return out
+

 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
         self.approximate = approximate

-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
-        d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
-
-    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -103,6 +106,24 @@ class GeluAndMul(CustomOp):
             raise RuntimeError("GeluAndMul only support tanh or none")
         return out

+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
+
+    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
+        if _is_cpu_amx_available and self.approximate == "tanh":
+            return torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
+        elif _is_cpu_amx_available and self.approximate == "none":
+            return torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
+        else:
+            return self.forward_native(x)
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        return self._forward_impl(x)
+
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        return self._forward_impl(x)
+
     def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
         y_npu, gelu_npu = torch_npu.npu_geglu(
             x,
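
Note: after the two GeluAndMul hunks above, forward_cuda and forward_xpu share the kernel-backed _forward_impl, while forward_native remains the pure-PyTorch reference and the new forward_cpu falls back to it when AMX kernels are unavailable. A minimal consistency check one could run (a sketch assuming a CUDA machine with sgl_kernel installed; the tolerances are illustrative, not from the source):

    import torch

    from sglang.srt.layers.activation import GeluAndMul

    def check_gelu_and_mul(device: str = "cuda") -> None:
        act = GeluAndMul(approximate="tanh")
        # The op gates the first half of the last dim with GELU and multiplies
        # by the second half, so the input width must be even.
        x = torch.randn(4, 512, device=device, dtype=torch.float16)
        fused = act.forward_cuda(x)        # sgl_kernel path via _forward_impl
        reference = act.forward_native(x)  # pure-PyTorch reference
        torch.testing.assert_close(fused, reference, rtol=2e-2, atol=2e-2)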
@@ -150,6 +171,116 @@ class QuickGELU(CustomOp):
         return torch_npu.npu_fast_gelu(x)


+class XIELU(CustomOp):
+    """
+    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
+    If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
+    Otherwise, we emit a single warning and use xIELU Python
+    """
+
+    def __init__(
+        self,
+        alpha_p_init: float = 0.8,
+        alpha_n_init: float = 0.8,
+        beta: float = 0.5,
+        eps: float = -1e-6,
+        dtype: torch.dtype = torch.bfloat16,
+        with_vector_loads: bool = False,
+    ):
+        super().__init__()
+        self.alpha_p = nn.Parameter(
+            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(
+                0
+            )
+        )
+        self.alpha_n = nn.Parameter(
+            torch.log(
+                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1
+            ).unsqueeze(0)
+        )
+        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
+        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
+        self.with_vector_loads = with_vector_loads
+        # Temporary until xIELU CUDA fully implemented
+        self._beta_scalar = float(self.beta.detach().cpu().float().item())
+        self._eps_scalar = float(self.eps.detach().cpu().float().item())
+
+        self._xielu_cuda_obj = None
+        try:
+            import xielu.ops  # noqa: F401
+
+            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
+            msg = "Using experimental xIELU CUDA."
+            try:
+                from torch._dynamo import allow_in_graph
+
+                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
+                msg += " Enabled torch._dynamo for xIELU CUDA."
+            except Exception as err:
+                msg += (
+                    f" Could not enable torch._dynamo for xIELU ({err}) - "
+                    "this may result in slower performance."
+                )
+                self._xielu_cuda_fn = self._xielu_cuda
+            logger.warning_once(msg)
+        except Exception as err:
+            pass
+            # logger.warning_once(
+            #     "CUDA-fused xIELU not available (%s) –"
+            #     " falling back to a Python version.\n"
+            #     "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
+            #     str(err),
+            # )
+
+    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
+        alpha_p = nn.functional.softplus(self.alpha_p)
+        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
+        return torch.where(
+            x > 0,
+            alpha_p * x * x + self.beta * x,
+            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
+        )
+
+    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        """Firewall function to prevent torch.compile from seeing .item()"""
+        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
+        original_shape = x.shape
+        # CUDA kernel expects 3D tensors, reshape if needed
+        while x.dim() < 3:
+            x = x.unsqueeze(0)
+        if x.dim() > 3:
+            x = x.view(-1, 1, x.size(-1))
+        if original_shape != x.shape:
+            logger.warning_once(
+                "Warning: xIELU input tensor expects 3 dimensions"
+                " but got (shape: %s). Reshaping to (shape: %s).\n"
+                "Note: For SGLang this may be expected if sending"
+                "[B*S,D] instead of [B,S,D].",
+                original_shape,
+                x.shape,
+            )
+        result = self._xielu_cuda_obj.forward(
+            x,
+            self.alpha_p,
+            self.alpha_n,
+            # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
+            self._beta_scalar,
+            self._eps_scalar,
+            self.with_vector_loads,
+        )
+        return result.view(original_shape)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self._xielu_cuda_obj is not None and input.is_cuda:
+            if not torch._dynamo.is_compiling():
+                return self._xielu_cuda_fn(input)
+            else:
+                logger.warning_once(
+                    "torch._dynamo is compiling, using Python version of xIELU."
+                )
+        return self._xielu_python(input)
+
+
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.

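Note: written out, the _xielu_python fallback above computes the piecewise xIELU function, with alpha_p = softplus(alpha_p) and alpha_n = beta + softplus(alpha_n); the softplus reparameterization keeps alpha_p positive and alpha_n above beta for any raw parameter values:

    f(x) =
    \begin{cases}
      \alpha_p x^{2} + \beta x, & x > 0 \\
      \alpha_n \left( e^{\min(x,\,\varepsilon)} - 1 - x \right) + \beta x, & x \le 0
    \end{cases}
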
@@ -197,6 +328,7 @@ _ACTIVATION_REGISTRY = {
     "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
     "gelu_new": NewGELU(),
     "relu2": ReLU2(),
+    "xielu": XIELU(),
 }


@@ -242,7 +374,9 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
     return nn.Identity()


-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
+if not (
+    _is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip or _is_xpu
+):
     logger.info(
         "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
sglang/srt/layers/attention/aiter_backend.py

@@ -4,18 +4,13 @@ from __future__ import annotations
 end to end attention solution with aiter kernels
 """

-import math
-import os
 from dataclasses import dataclass
 from enum import Enum, auto
-from functools import partial
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Optional

 import torch
 import triton
-import triton.language as tl

-from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import (
@@ -27,7 +22,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput

 try:
     from aiter import (
@@ -374,7 +369,7 @@ class AiterAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             qo_indptr = None
@@ -509,7 +504,7 @@ class AiterAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -619,7 +614,11 @@ class AiterAttnBackend(AttentionBackend):
         assert len(k.shape) == 3
         assert len(v.shape) == 3

-        if forward_batch.forward_mode.is_extend():
+        if (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ):
             if kv_indices.shape[0] == 0:
                 o = flash_attn_varlen_func(
                     q,
@@ -884,7 +883,7 @@ class AiterIndicesUpdaterPrefill:
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -896,7 +895,7 @@ class AiterIndicesUpdaterPrefill:
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):

         kv_start_idx = None
@@ -980,7 +979,7 @@ class AiterMlaIndicesUpdaterPrefill:
         extend_lens: torch.Tensor,
         max_q_len: int,
         max_kv_len: int,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -993,7 +992,7 @@ class AiterMlaIndicesUpdaterPrefill:
         extend_lens: torch.Tensor,
         max_q_len: int,
         max_kv_len: int,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         bs = len(req_pool_indices)

@@ -1050,7 +1049,7 @@ class AiterMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
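
Note: the widened condition in the @@ -619 hunk above matters because target-verify and draft-extend are the extend-flavored modes used by speculative decoding; assuming ForwardMode.is_extend() also returns True for them (consistent with this diff), they must be excluded explicitly so those batches skip the varlen prefill branch. Restated as a sketch:

    from sglang.srt.model_executor.forward_batch_info import ForwardMode

    def takes_prefill_branch(mode: ForwardMode) -> bool:
        # Plain prefill only: the speculative extend flavors fall through to
        # the later decode/verify handling instead.
        return (
            mode.is_extend()
            and not mode.is_target_verify()
            and not mode.is_draft_extend()
        )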
sglang/srt/layers/attention/ascend_backend.py

@@ -5,13 +5,15 @@ from typing import TYPE_CHECKING, List, Optional

 import torch
 import torch_npu
-from torch.nn.functional import scaled_dot_product_attention

 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess_enabled
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import get_bool_env_var

 if TYPE_CHECKING:
@@ -33,6 +35,9 @@ class ForwardMetadata:
     extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_list: Optional[List[int]] = None
+    seq_lens_list_cumsum: Optional[List[int]] = None
+    seq_lens: Optional[torch.Tensor] = None
+    actual_seq_lengths_q: Optional[torch.Tensor] = None


 class AscendAttnBackend(AttentionBackend):
@@ -64,6 +69,9 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+            self.q_head_dim = (
+                self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
+            )
         self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
@@ -83,6 +91,7 @@ class AscendAttnBackend(AttentionBackend):

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
+        tp_size = get_attention_tp_size()
         self.forward_metadata = ForwardMetadata()

         self.forward_metadata.block_tables = (
@@ -96,9 +105,9 @@
                 forward_batch.extend_seq_lens.cpu().int()
             )
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
-        self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
-            forward_batch.extend_seq_lens_cpu
-        )
+
+        seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
+        self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum

         self.graph_mode = False

@@ -119,12 +128,16 @@
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         metadata = ForwardMetadata()

         metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
         metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
+        metadata.seq_lens = seq_lens
+        metadata.actual_seq_lengths_q = torch.tensor(
+            [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
+        )

         self.graph_metadata[bs] = metadata
         self.forward_metadata = metadata
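
Note: actual_seq_lengths_q is captured here for the decode graph, where every request contributes exactly one query token per step, so the cumulative query lengths are simply 1..bs. An equivalent construction (illustrative sketch only):

    import torch

    bs = 4
    # [1 + i * 1 for i in range(bs)] is the running query-token count with
    # one token per request, i.e. just 1..bs:
    actual_seq_lengths_q = torch.arange(1, bs + 1, dtype=torch.int32)
    assert actual_seq_lengths_q.tolist() == [1 + i * 1 for i in range(bs)]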
@@ -139,7 +152,7 @@
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         metadata = self.graph_metadata[bs]
@@ -153,6 +166,8 @@
             metadata.block_tables[:bs, max_seq_pages:].fill_(0)
             metadata.block_tables[bs:, :].fill_(0)

+        metadata.seq_lens[:bs].copy_(seq_lens[:bs])
+
         self.forward_metadata = metadata

         self.graph_mode = True
@@ -160,6 +175,64 @@
     def get_cuda_graph_seq_len_fill_value(self):
         return 0

+    def forward_sparse(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: torch.Tensor = None,
+    ):
+
+        is_prefill = forward_batch.forward_mode.is_extend()
+
+        if save_kv_cache:
+            k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+            k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
+        q_nope, q_pe = q, q_rope
+        k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        block_table = self.forward_metadata.block_tables
+        if is_prefill:
+            actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
+        else:
+            if self.forward_metadata.actual_seq_lengths_q is None:
+                actual_seq_qlen = (
+                    torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
+                )
+            else:
+                actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
+        if self.forward_metadata.seq_lens_cpu_int is None:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens
+        else:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int

+        attn_out = torch.ops.custom.npu_sparse_flash_attention(
+            query=q_nope,
+            key=k_nope,
+            value=k_nope,
+            query_rope=q_pe,
+            key_rope=k_pe,
+            sparse_indices=topk_indices,
+            scale_value=layer.scaling,
+            actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
+            actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
+            block_table=block_table,
+            sparse_block_size=1,
+            layout_query="TND",
+            layout_kv="PA_BSND",
+            sparse_mode=3,
+        )
+
+        return attn_out
+
     def forward_extend(
         self,
         q,
@@ -168,7 +241,23 @@
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
         if not self.use_mla:
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -368,7 +457,7 @@
             -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
         )

-        q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank)
+        q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank).contiguous()
         q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
         if self.forward_metadata.seq_lens_cpu_int is None:
             actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
@@ -394,7 +483,7 @@
             antiquant_scale=None,
             sparse_mode=0,
         )
-        output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
+        output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
         softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)

         torch_npu.npu_fused_infer_attention_score.out(
@@ -429,7 +518,24 @@
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if is_mla_preprocess_enabled():
+            # MLAPO does saving kv_cache
+            save_kv_cache = False
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
+
         if self.graph_mode:
             return self.forward_decode_graph(
                 q,