sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/aiter_backend.py
@@ -18,7 +18,10 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
-from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
@@ -154,6 +157,8 @@ class AiterAttnBackend(AttentionBackend):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
 
+        self.enable_dp_attention = is_dp_attention_enabled()
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
 
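The two aiter_backend.py hunks above import `is_dp_attention_enabled` alongside `get_attention_tp_size` and cache its result on the backend at construction time. A minimal sketch of that caching pattern follows; the stub here stands in for the real `sglang.srt.layers.dp_attention.is_dp_attention_enabled` helper and is not its implementation.

```python
# Sketch only: query a global DP-attention flag once at backend init and reuse it,
# mirroring `self.enable_dp_attention = is_dp_attention_enabled()` in the hunk above.
def is_dp_attention_enabled_stub() -> bool:
    # Hypothetical stand-in; the real helper reads sglang's parallel/server state.
    return False


class AttnBackendSketch:
    def __init__(self) -> None:
        # Cached once instead of being re-queried on every forward pass.
        self.enable_dp_attention = is_dp_attention_enabled_stub()

    def init_forward_metadata(self) -> None:
        if self.enable_dp_attention:
            pass  # DP-attention-specific metadata would be prepared here


print(AttnBackendSketch().enable_dp_attention)  # False with the stub above
```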
@@ -302,19 +307,19 @@ class AiterAttnBackend(AttentionBackend):
             if self.use_mla:
                 self.mla_indices_updater_prefill.update(
                     forward_batch.req_pool_indices,
-                    forward_batch.extend_prefix_lens,
-                    sum(forward_batch.extend_prefix_lens_cpu),
+                    forward_batch.seq_lens,
+                    forward_batch.seq_lens_sum,
                     forward_batch.extend_seq_lens,
-                    max(forward_batch.extend_seq_lens_cpu),
-                    forward_batch.seq_lens_cpu.max().item(),
+                    forward_batch.extend_seq_lens.max().item(),
+                    forward_batch.seq_lens.max().item(),
                     spec_info=None,
                 )
-                self.mla_indices_updater_prefill.kv_indptr += (
-                    self.mla_indices_updater_prefill.qo_indptr
-                )
+
+                kv_indices = self.mla_indices_updater_prefill.kv_indices
+
                 self.forward_metadata = ForwardMetadata(
                     self.mla_indices_updater_prefill.kv_indptr,
-                    self.mla_indices_updater_prefill.kv_indices,
+                    kv_indices,
                     self.mla_indices_updater_prefill.qo_indptr,
                     self.kv_last_page_len[:bs],
                     self.mla_indices_updater_prefill.max_q_len,
@@ -614,66 +619,86 @@ class AiterAttnBackend(AttentionBackend):
             assert len(k.shape) == 3
             assert len(v.shape) == 3
 
-            if kv_indices.shape[0] == 0:
-                o = flash_attn_varlen_func(
-                    q,
-                    k,
-                    v,
-                    qo_indptr,
-                    qo_indptr,
-                    max_q_len,
-                    max_q_len,
-                    softmax_scale=layer.scaling,
-                    causal=True,
-                )
-                return o
-            elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
-                K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
-                kvc, k_pe = torch.split(
-                    K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
-                )
-                kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
+            if forward_batch.forward_mode.is_extend():
+                if kv_indices.shape[0] == 0:
+                    o = flash_attn_varlen_func(
+                        q,
+                        k,
+                        v,
+                        qo_indptr,
+                        qo_indptr,
+                        max_q_len,
+                        max_q_len,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                    )
+                    return o
+                elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
+                    K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
+                    kvc, k_pe = torch.split(
+                        K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
+                    )
+                    kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
 
-                kvprefix = kvprefix.view(
-                    -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
-                )
-                k_prefix, v_prefix = torch.split(
-                    kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
-                )
-                k_prefix = torch.cat(
-                    [
-                        k_prefix,
-                        torch.broadcast_to(
-                            k_pe,
-                            (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
-                        ),
-                    ],
-                    dim=-1,
-                )
-                assert (
-                    forward_batch.extend_prefix_lens.shape
-                    == forward_batch.extend_seq_lens.shape
-                )
-                k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
-                k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
-                assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
-                k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
-                v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
-                v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
-                v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
-
-                o = flash_attn_varlen_func(
-                    q,
-                    k,
-                    v,
-                    qo_indptr,
-                    kv_indptr,
-                    max_q_len,
-                    max_kv_len,
-                    softmax_scale=layer.scaling,
-                    causal=True,
-                )
-                return o
+                    kvprefix = kvprefix.view(
+                        -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
+                    )
+                    k_prefix, v_prefix = torch.split(
+                        kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
+                    )
+                    k_prefix = torch.cat(
+                        [
+                            k_prefix,
+                            torch.broadcast_to(
+                                k_pe,
+                                (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
+                            ),
+                        ],
+                        dim=-1,
+                    )
+                    assert (
+                        forward_batch.extend_prefix_lens.shape
+                        == forward_batch.extend_seq_lens.shape
+                    )
+
+                    k = k_prefix
+                    v = v_prefix
+
+                    o = flash_attn_varlen_func(
+                        q,
+                        k,
+                        v,
+                        qo_indptr,
+                        kv_indptr,
+                        max_q_len,
+                        max_kv_len,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                    )
+                    return o
+
+                else:
+                    if layer.qk_head_dim != layer.v_head_dim:
+                        o = q.new_empty(
+                            (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                        )
+                    else:
+                        o = torch.empty_like(q)
+
+                    mla_prefill_fwd(
+                        q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                        K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                        o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                        qo_indptr,
+                        kv_indptr,
+                        kv_indices,
+                        self.forward_metadata.kv_last_page_len,
+                        self.forward_metadata.max_q_len,
+                        layer.scaling,
+                        layer.logit_cap,
+                    )
+                    K_Buffer = K_Buffer.view(-1, layer.tp_k_head_num, layer.qk_head_dim)
+                    return o
             elif forward_batch.forward_mode.is_target_verify():
                 o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
                 mla_decode_fwd(
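For orientation, a stubbed sketch of the dispatch that the rewritten extend path above introduces in the aiter MLA branch: extend-mode batches either run the varlen flash-attention kernel (expanding the cached latent KV through `kv_b_proj` when the query head dim differs from `kv_lora_rank + qk_rope_head_dim`) or, when the dims match, the `mla_prefill_fwd` kernel. The strings returned here are labels only, not calls into aiter.

```python
def pick_mla_extend_kernel(
    num_kv_indices: int,
    qk_head_dim: int,
    kv_lora_rank: int,
    qk_rope_head_dim: int,
) -> str:
    """Label-only sketch of the branch structure in the hunk above."""
    if num_kv_indices == 0:
        return "flash_attn_varlen_func over the fresh q/k/v only"
    if qk_head_dim != kv_lora_rank + qk_rope_head_dim:
        return "kv_b_proj expansion + flash_attn_varlen_func"
    return "mla_prefill_fwd"


print(pick_mla_extend_kernel(0, 192, 512, 64))     # varlen, no cached KV indices
print(pick_mla_extend_kernel(1024, 192, 512, 64))  # expansion + varlen
print(pick_mla_extend_kernel(1024, 576, 512, 64))  # mla_prefill_fwd (576 == 512 + 64)
```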
sglang/srt/layers/attention/ascend_backend.py
@@ -10,6 +10,7 @@ from torch.nn.functional import scaled_dot_product_attention
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import get_bool_env_var
@@ -33,6 +34,7 @@ class ForwardMetadata:
     extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_list: Optional[List[int]] = None
+    seq_lens_list_cumsum: Optional[List[int]] = None
 
 
 class AscendAttnBackend(AttentionBackend):
@@ -64,7 +66,7 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
-            self.native_attn = TorchNativeAttnBackend(model_runner)
+        self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
@@ -83,6 +85,7 @@ class AscendAttnBackend(AttentionBackend):
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
+        tp_size = get_attention_tp_size()
         self.forward_metadata = ForwardMetadata()
 
         self.forward_metadata.block_tables = (
  self.forward_metadata.block_tables = (
@@ -96,9 +99,13 @@ class AscendAttnBackend(AttentionBackend):
96
99
  forward_batch.extend_seq_lens.cpu().int()
97
100
  )
98
101
  self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
99
- self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
100
- forward_batch.extend_seq_lens_cpu
101
- )
102
+
103
+ seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
104
+ if forward_batch.is_extend_in_batch:
105
+ seq_lens_list_cumsum[-1] = (
106
+ (seq_lens_list_cumsum[-1] - 1) // tp_size + 1
107
+ ) * tp_size
108
+ self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
102
109
 
103
110
  self.graph_mode = False
104
111
 
@@ -158,7 +165,7 @@ class AscendAttnBackend(AttentionBackend):
         self.graph_mode = True
 
     def get_cuda_graph_seq_len_fill_value(self):
-        return 1
+        return 0
 
     def forward_extend(
         self,
@@ -167,7 +174,7 @@ class AscendAttnBackend(AttentionBackend):
         v,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache=True,
+        save_kv_cache: bool = True,
     ):
         if not self.use_mla:
             if save_kv_cache:
@@ -180,7 +187,7 @@ class AscendAttnBackend(AttentionBackend):
 
             if self.use_fia:
                 """FIA will support multi-bs in the later version of CANN"""
-                q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
                 attn_output = torch.empty(
                     (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
                     device=q.device,
@@ -208,26 +215,61 @@ class AscendAttnBackend(AttentionBackend):
                 )
 
             else:
-                query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
-                attn_output = torch.empty(
-                    (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
-                    dtype=query.dtype,
-                    device=query.device,
-                )
+                if layer.qk_head_dim <= 128:
+                    query = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+                    attn_output = torch.empty(
+                        (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                        dtype=query.dtype,
+                        device=query.device,
+                    )
 
-                torch_npu._npu_flash_attention_qlens(
-                    query=query,
-                    key_cache=k_cache,
-                    value_cache=v_cache,
-                    mask=self.mask,
-                    block_table=self.forward_metadata.block_tables,
-                    seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
-                    context_lens=self.forward_metadata.seq_lens_cpu_int,
-                    scale_value=layer.scaling,
-                    num_heads=layer.tp_q_head_num,
-                    num_kv_heads=layer.tp_k_head_num,
-                    out=attn_output,
-                )
+                    torch_npu._npu_flash_attention_qlens(
+                        query=query,
+                        key_cache=k_cache,
+                        value_cache=v_cache,
+                        mask=self.mask,
+                        block_table=self.forward_metadata.block_tables,
+                        seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
+                        context_lens=self.forward_metadata.seq_lens_cpu_int,
+                        scale_value=layer.scaling,
+                        num_heads=layer.tp_q_head_num,
+                        num_kv_heads=layer.tp_k_head_num,
+                        out=attn_output,
+                    )
+                else:
+                    if layer.qk_head_dim != layer.v_head_dim:
+                        attn_output = q.new_empty(
+                            (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                        )
+                    else:
+                        attn_output = torch.empty_like(q)
+
+                    use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+                    q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                    o_ = attn_output.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+                    causal = True
+                    if (
+                        layer.is_cross_attention
+                        or layer.attn_type == AttentionType.ENCODER_ONLY
+                    ):
+                        causal = False
+
+                    self.native_attn._run_sdpa_forward_extend(
+                        q_,
+                        o_,
+                        k_cache.view(-1, layer.tp_k_head_num, layer.qk_head_dim),
+                        v_cache.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                        forward_batch.req_to_token_pool.req_to_token,
+                        forward_batch.req_pool_indices,
+                        forward_batch.seq_lens,
+                        forward_batch.extend_prefix_lens,
+                        forward_batch.extend_seq_lens,
+                        scaling=layer.scaling,
+                        enable_gqa=use_gqa,
+                        causal=causal,
+                    )
         else:
             assert (
                 layer.qk_head_dim != layer.v_head_dim
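The non-FIA extend path above now branches on head size: `qk_head_dim <= 128` keeps the `torch_npu._npu_flash_attention_qlens` kernel, while larger head dims fall back to the torch-native backend's SDPA extend. A stubbed sketch of just that dispatch; the two branch functions are placeholders, not the torch_npu / sglang APIs.

```python
# Dispatch sketch for the Ascend extend path above; both branches are stubs.
def npu_qlens_kernel() -> str:
    return "npu_flash_attention_qlens"  # stand-in for torch_npu._npu_flash_attention_qlens


def sdpa_extend_fallback() -> str:
    return "sdpa_forward_extend"  # stand-in for TorchNativeAttnBackend._run_sdpa_forward_extend


def run_extend(qk_head_dim: int, use_fia: bool) -> str:
    if use_fia:
        return "fused_infer_attention"  # FIA path, unchanged by this hunk
    if qk_head_dim <= 128:
        return npu_qlens_kernel()
    return sdpa_extend_fallback()


assert run_extend(128, use_fia=False) == "npu_flash_attention_qlens"
assert run_extend(192, use_fia=False) == "sdpa_forward_extend"
```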
@@ -253,6 +295,136 @@ class AscendAttnBackend(AttentionBackend):
 
         return attn_output
 
+    def forward_decode_graph(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ):
+        if save_kv_cache:
+            if self.use_mla:
+                k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+                k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, k_rope
+                )
+            else:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )
+
+        if not self.use_mla:
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
+                layer.layer_id
+            ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
+            query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
+            if self.forward_metadata.seq_lens_cpu_int is None:
+                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
+            else:
+                actual_seq_len_kv = (
+                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                )
+            num_tokens = query.shape[0]
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                query,
+                k_cache,
+                v_cache,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                input_layout="BSH",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+            )
+            output = torch.empty(
+                (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
+                dtype=q.dtype,
+                device=q.device,
+            )
+            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+            torch_npu.npu_fused_infer_attention_score.out(
+                query,
+                k_cache,
+                v_cache,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                input_layout="BSH",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                workspace=workspace,
+                out=[output, softmax_lse],
+            )
+            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            k_rope_cache = k_rope.view(
+                -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
+            )
+            c_kv_cache = c_kv.view(
+                -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
+            )
+
+            q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank).contiguous()
+            q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
+            if self.forward_metadata.seq_lens_cpu_int is None:
+                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
+            else:
+                actual_seq_len_kv = (
+                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                )
+
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                q_nope,
+                c_kv_cache,
+                c_kv_cache,
+                query_rope=q_rope,
+                key_rope=k_rope_cache,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                input_layout="BNSD",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                sparse_mode=0,
+            )
+            output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
+            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                c_kv_cache,
+                c_kv_cache,
+                query_rope=q_rope,
+                key_rope=k_rope_cache,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                input_layout="BNSD",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                sparse_mode=0,
+                workspace=workspace,
+                out=[output, softmax_lse],
+            )
+            return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)
+
     def forward_decode(
         self,
         q: torch.Tensor,
260
432
  v: torch.Tensor,
261
433
  layer: RadixAttention,
262
434
  forward_batch: ForwardBatch,
263
- save_kv_cache: bool = False,
435
+ save_kv_cache: bool = True,
264
436
  # For multi-head latent attention
265
437
  q_rope: Optional[torch.Tensor] = None,
266
438
  k_rope: Optional[torch.Tensor] = None,
267
439
  ):
440
+ if self.graph_mode:
441
+ return self.forward_decode_graph(
442
+ q,
443
+ k,
444
+ v,
445
+ layer,
446
+ forward_batch,
447
+ save_kv_cache,
448
+ q_rope=q_rope,
449
+ k_rope=k_rope,
450
+ )
451
+
268
452
  if not self.use_mla:
269
453
  if save_kv_cache:
270
454
  forward_batch.token_to_kv_pool.set_kv_buffer(
271
455
  layer, forward_batch.out_cache_loc, k, v
272
456
  )
273
457
  num_tokens = q.shape[0]
274
- if self.graph_mode:
275
- k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
276
- layer.layer_id
277
- ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
278
- v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
279
- layer.layer_id
280
- ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
281
- query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
282
- workspace = (
283
- torch_npu._npu_fused_infer_attention_score_get_max_workspace(
284
- query,
285
- k_cache,
286
- v_cache,
287
- block_table=self.forward_metadata.block_tables,
288
- block_size=self.page_size,
289
- num_heads=layer.tp_q_head_num,
290
- num_key_value_heads=layer.tp_k_head_num,
291
- input_layout="BSH",
292
- scale=layer.scaling,
293
- actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
294
- )
295
- )
296
- attn_output = torch.empty(
297
- (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
298
- dtype=q.dtype,
299
- device=q.device,
300
- )
301
- softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
302
- torch_npu.npu_fused_infer_attention_score.out(
303
- query,
304
- k_cache,
305
- v_cache,
306
- block_table=self.forward_metadata.block_tables,
307
- block_size=self.page_size,
458
+ k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
459
+ v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
460
+ if self.use_fia:
461
+ attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
462
+ q.view(
463
+ forward_batch.batch_size,
464
+ -1,
465
+ layer.tp_q_head_num,
466
+ layer.qk_head_dim,
467
+ ),
468
+ k_cache.view(
469
+ -1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
470
+ ),
471
+ v_cache.view(
472
+ -1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
473
+ ),
308
474
  num_heads=layer.tp_q_head_num,
309
475
  num_key_value_heads=layer.tp_k_head_num,
310
- input_layout="BSH",
476
+ input_layout="BSND",
477
+ atten_mask=None,
478
+ block_size=self.page_size,
479
+ block_table=self.forward_metadata.block_tables,
480
+ actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
311
481
  scale=layer.scaling,
312
- actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
313
- workspace=workspace,
314
- out=[attn_output, softmax_lse],
315
482
  )
316
483
  else:
317
- k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
318
- v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
319
- layer.layer_id
484
+ query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
485
+ num_tokens = query.shape[0]
486
+ attn_output = torch.empty(
487
+ (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
488
+ dtype=query.dtype,
489
+ device=query.device,
320
490
  )
321
- if self.use_fia:
322
- attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
323
- q.view(
324
- forward_batch.batch_size,
325
- -1,
326
- layer.tp_q_head_num,
327
- layer.qk_head_dim,
328
- ),
329
- k_cache.view(
330
- -1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
331
- ),
332
- v_cache.view(
333
- -1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
334
- ),
335
- num_heads=layer.tp_q_head_num,
336
- num_key_value_heads=layer.tp_k_head_num,
337
- input_layout="BSND",
338
- atten_mask=None,
339
- block_size=self.page_size,
340
- block_table=self.forward_metadata.block_tables,
341
- actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
342
- scale=layer.scaling,
343
- )
344
- else:
345
- query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
346
- attn_output = torch.empty(
347
- (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
348
- dtype=query.dtype,
349
- device=query.device,
350
- )
351
491
 
352
- torch_npu._npu_paged_attention(
353
- query=query,
354
- key_cache=k_cache,
355
- value_cache=v_cache,
356
- num_heads=layer.tp_q_head_num,
357
- num_kv_heads=layer.tp_k_head_num,
358
- scale_value=layer.scaling,
359
- block_table=self.forward_metadata.block_tables,
360
- context_lens=self.forward_metadata.seq_lens_cpu_int,
361
- out=attn_output,
362
- )
492
+ torch_npu._npu_paged_attention(
493
+ query=query,
494
+ key_cache=k_cache,
495
+ value_cache=v_cache,
496
+ num_heads=layer.tp_q_head_num,
497
+ num_kv_heads=layer.tp_k_head_num,
498
+ scale_value=layer.scaling,
499
+ block_table=self.forward_metadata.block_tables,
500
+ context_lens=self.forward_metadata.seq_lens_cpu_int,
501
+ out=attn_output,
502
+ )
363
503
  return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
364
504
  else:
365
505
  if save_kv_cache:
@@ -370,9 +510,7 @@ class AscendAttnBackend(AttentionBackend):
             kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
             k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
 
-            if (self.graph_mode or self.use_fia) and (
-                layer.tp_q_head_num // layer.tp_k_head_num
-            ) >= 8:
+            if self.use_fia and (layer.tp_q_head_num // layer.tp_k_head_num) >= 8:
                 """layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
                 kv_c = kv_c.view(
                     -1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank