sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/ascend_backend.py +381 -136
@@ -10,13 +10,19 @@ from torch.nn.functional import scaled_dot_product_attention
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import get_bool_env_var
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+import os
+
+import numpy as np
+
 
 @dataclass
 class ForwardMetadata:
@@ -28,6 +34,7 @@ class ForwardMetadata:
     extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_list: Optional[List[int]] = None
+    seq_lens_list_cumsum: Optional[List[int]] = None
 
 
 class AscendAttnBackend(AttentionBackend):
@@ -54,20 +61,31 @@ class AscendAttnBackend(AttentionBackend):
         super().__init__()
         self.forward_metadata = None
         self.device = model_runner.device
-        self.gen_attention_mask(128, model_runner.dtype)
         self.page_size = model_runner.page_size
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
-            self.native_attn = TorchNativeAttnBackend(model_runner)
+        self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.graph_mode = False
+        self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False")
+        if not self.use_fia:
+            self.gen_attention_mask(128, model_runner.dtype)
+        mask_length = 2048
+        self.fia_mask = ~torch.tril(
+            torch.ones(
+                (mask_length, mask_length),
+                dtype=torch.bool,
+                device=model_runner.device,
+            )
+        )
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
+        tp_size = get_attention_tp_size()
         self.forward_metadata = ForwardMetadata()
 
         self.forward_metadata.block_tables = (
@@ -82,6 +100,13 @@ class AscendAttnBackend(AttentionBackend):
         )
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
 
+        seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
+        if forward_batch.is_extend_in_batch:
+            seq_lens_list_cumsum[-1] = (
+                (seq_lens_list_cumsum[-1] - 1) // tp_size + 1
+            ) * tp_size
+        self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
+
         self.graph_mode = False
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
@@ -140,7 +165,7 @@ class AscendAttnBackend(AttentionBackend):
         self.graph_mode = True
 
     def get_cuda_graph_seq_len_fill_value(self):
-        return 1
+        return 0
 
     def forward_extend(
         self,
@@ -149,73 +174,256 @@ class AscendAttnBackend(AttentionBackend):
         v,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache=True,
+        save_kv_cache: bool = True,
     ):
-        if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
+        if not self.use_mla:
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )
+
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+            if self.use_fia:
+                """FIA will support multi-bs in the later version of CANN"""
+                q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                attn_output = torch.empty(
+                    (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
+                    device=q.device,
+                    dtype=q.dtype,
+                )
+                q_len_offset = 0
+                for q_len in forward_batch.extend_seq_lens_cpu:
+                    attn_output[q_len_offset : q_len_offset + q_len] = (
+                        torch.ops.npu.npu_fused_infer_attention_score(
+                            q[None, q_len_offset : q_len_offset + q_len],
+                            k[None, q_len_offset : q_len_offset + q_len],
+                            v[None, q_len_offset : q_len_offset + q_len],
+                            num_heads=layer.tp_q_head_num,
+                            num_key_value_heads=layer.tp_k_head_num,
+                            input_layout="BSND",  # todo, TND not supports q_heads!=k_heads
+                            atten_mask=self.fia_mask.unsqueeze(0),
+                            sparse_mode=3,
+                            scale=layer.scaling,
+                            next_tokens=0,
+                        )[0]
+                    )
+                    q_len_offset += q_len
+                attn_output = attn_output.view(
+                    -1, layer.tp_q_head_num * layer.v_head_dim
+                )
+
+            else:
+                if layer.qk_head_dim <= 128:
+                    query = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+                    attn_output = torch.empty(
+                        (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                        dtype=query.dtype,
+                        device=query.device,
+                    )
+
+                    torch_npu._npu_flash_attention_qlens(
+                        query=query,
+                        key_cache=k_cache,
+                        value_cache=v_cache,
+                        mask=self.mask,
+                        block_table=self.forward_metadata.block_tables,
+                        seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
+                        context_lens=self.forward_metadata.seq_lens_cpu_int,
+                        scale_value=layer.scaling,
+                        num_heads=layer.tp_q_head_num,
+                        num_kv_heads=layer.tp_k_head_num,
+                        out=attn_output,
+                    )
+                else:
+                    if layer.qk_head_dim != layer.v_head_dim:
+                        attn_output = q.new_empty(
+                            (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                        )
+                    else:
+                        attn_output = torch.empty_like(q)
+
+                    use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+                    q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                    o_ = attn_output.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+                    causal = True
+                    if (
+                        layer.is_cross_attention
+                        or layer.attn_type == AttentionType.ENCODER_ONLY
+                    ):
+                        causal = False
+
+                    self.native_attn._run_sdpa_forward_extend(
+                        q_,
+                        o_,
+                        k_cache.view(-1, layer.tp_k_head_num, layer.qk_head_dim),
+                        v_cache.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                        forward_batch.req_to_token_pool.req_to_token,
+                        forward_batch.req_pool_indices,
+                        forward_batch.seq_lens,
+                        forward_batch.extend_prefix_lens,
+                        forward_batch.extend_seq_lens,
+                        scaling=layer.scaling,
+                        enable_gqa=use_gqa,
+                        causal=causal,
+                    )
+        else:
+            assert (
+                layer.qk_head_dim != layer.v_head_dim
+            ), "FIA only supports qk_head_dim != v_head_dim"
+            q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
+            k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
+
+            attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                q_nope,
+                k_nope,
+                v,
+                query_rope=q_rope,
+                key_rope=k_rope,
+                num_heads=layer.tp_q_head_num,
+                input_layout="TND",
+                atten_mask=self.fia_mask,
+                sparse_mode=3,
+                actual_seq_lengths=self.forward_metadata.seq_lens_list_cumsum,
+                actual_seq_lengths_kv=self.forward_metadata.seq_lens_list_cumsum,
+                scale=layer.scaling,
+                next_tokens=0,
             )
 
-        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-        v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+        return attn_output
+
+    def forward_decode_graph(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ):
+        if save_kv_cache:
+            if self.use_mla:
+                k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+                k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, k_rope
+                )
+            else:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )
 
         if not self.use_mla:
-            query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
+                layer.layer_id
+            ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
+            query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
+            if self.forward_metadata.seq_lens_cpu_int is None:
+                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
+            else:
+                actual_seq_len_kv = (
+                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                )
+            num_tokens = query.shape[0]
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                query,
+                k_cache,
+                v_cache,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                input_layout="BSH",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+            )
             output = torch.empty(
-                (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
-                dtype=query.dtype,
-                device=query.device,
+                (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
+                dtype=q.dtype,
+                device=q.device,
             )
-
-            torch_npu._npu_flash_attention_qlens(
-                query=query,
-                key_cache=k_cache,
-                value_cache=v_cache,
-                mask=self.mask,
+            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+            torch_npu.npu_fused_infer_attention_score.out(
+                query,
+                k_cache,
+                v_cache,
                 block_table=self.forward_metadata.block_tables,
-                seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
-                context_lens=self.forward_metadata.seq_lens_cpu_int,
-                scale_value=layer.scaling,
+                block_size=self.page_size,
                 num_heads=layer.tp_q_head_num,
-                num_kv_heads=layer.tp_k_head_num,
-                out=output,
+                num_key_value_heads=layer.tp_k_head_num,
+                input_layout="BSH",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                workspace=workspace,
+                out=[output, softmax_lse],
             )
-            return output
+            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
         else:
-            if layer.qk_head_dim != layer.v_head_dim:
-                o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+            c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            k_rope_cache = k_rope.view(
+                -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
+            )
+            c_kv_cache = c_kv.view(
+                -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
+            )
+
+            q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank).contiguous()
+            q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
+            if self.forward_metadata.seq_lens_cpu_int is None:
+                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
             else:
-                o = torch.empty_like(q)
-
-            use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
-
-            q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
-            o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
-
-            causal = True
-            if (
-                layer.is_cross_attention
-                or layer.attn_type == AttentionType.ENCODER_ONLY
-            ):
-                causal = False
-
-            self.native_attn._run_sdpa_forward_extend(
-                q_,
-                o_,
-                k_cache.view(
-                    -1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
-                ),
-                v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
-                forward_batch.req_to_token_pool.req_to_token,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.extend_prefix_lens,
-                forward_batch.extend_seq_lens,
-                scaling=layer.scaling,
-                enable_gqa=use_gqa,
-                causal=causal,
+                actual_seq_len_kv = (
+                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                )
+
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                q_nope,
+                c_kv_cache,
+                c_kv_cache,
+                query_rope=q_rope,
+                key_rope=k_rope_cache,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                input_layout="BNSD",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                sparse_mode=0,
+            )
+            output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
+            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                c_kv_cache,
+                c_kv_cache,
+                query_rope=q_rope,
+                key_rope=k_rope_cache,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                input_layout="BNSD",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                sparse_mode=0,
+                workspace=workspace,
+                out=[output, softmax_lse],
             )
-        return o
+            return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)
 
     def forward_decode(
         self,
@@ -224,65 +432,58 @@ class AscendAttnBackend(AttentionBackend):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache=True,
+        save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
-        if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
+        if self.graph_mode:
+            return self.forward_decode_graph(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope=q_rope,
+                k_rope=k_rope,
             )
+
         if not self.use_mla:
-            if self.graph_mode:
-                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
-                    layer.layer_id
-                ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
-                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
-                    layer.layer_id
-                ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
-                query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
-                num_tokens = query.shape[0]
-                workspace = (
-                    torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                        query,
-                        k_cache,
-                        v_cache,
-                        block_table=self.forward_metadata.block_tables,
-                        block_size=self.page_size,
-                        num_heads=layer.tp_q_head_num,
-                        num_key_value_heads=layer.tp_k_head_num,
-                        input_layout="BSH",
-                        scale=layer.scaling,
-                        actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
-                    )
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
                 )
-                output = torch.empty(
-                    (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
-                    dtype=q.dtype,
-                    device=q.device,
-                )
-                softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
-                torch_npu.npu_fused_infer_attention_score.out(
-                    query,
-                    k_cache,
-                    v_cache,
-                    block_table=self.forward_metadata.block_tables,
-                    block_size=self.page_size,
+            num_tokens = q.shape[0]
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+            if self.use_fia:
+                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                    q.view(
+                        forward_batch.batch_size,
+                        -1,
+                        layer.tp_q_head_num,
+                        layer.qk_head_dim,
+                    ),
+                    k_cache.view(
+                        -1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
+                    ),
+                    v_cache.view(
+                        -1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
+                    ),
                     num_heads=layer.tp_q_head_num,
                     num_key_value_heads=layer.tp_k_head_num,
-                    input_layout="BSH",
+                    input_layout="BSND",
+                    atten_mask=None,
+                    block_size=self.page_size,
+                    block_table=self.forward_metadata.block_tables,
+                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
                     scale=layer.scaling,
-                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
-                    workspace=workspace,
-                    out=[output, softmax_lse],
                 )
             else:
-                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
-                    layer.layer_id
-                )
-
-                query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
                 num_tokens = query.shape[0]
-                output = torch.empty(
+                attn_output = torch.empty(
                     (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
                     dtype=query.dtype,
                     device=query.device,
@@ -297,36 +498,80 @@ class AscendAttnBackend(AttentionBackend):
                 scale_value=layer.scaling,
                 block_table=self.forward_metadata.block_tables,
                 context_lens=self.forward_metadata.seq_lens_cpu_int,
-                out=output,
+                out=attn_output,
             )
-            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+            return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
         else:
-            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-            num_tokens = query.shape[0]
-            kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
-                layer.layer_id
-            )
-            kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
-                -1,
-                self.page_size,
-                layer.tp_k_head_num,
-                self.kv_lora_rank + self.qk_rope_head_dim,
-            )
-
-            attn_output = torch.empty(
-                [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
-                dtype=q.dtype,
-                device=q.device,
-            )
-            torch_npu._npu_paged_attention_mla(
-                query=query,
-                key_cache=kv_c_and_k_pe_cache,
-                num_kv_heads=layer.tp_k_head_num,
-                num_heads=layer.tp_q_head_num,
-                scale_value=layer.scaling,
-                block_table=self.forward_metadata.block_tables,
-                context_lens=self.forward_metadata.seq_lens_cpu_int,
-                mla_vheadsize=self.kv_lora_rank,
-                out=attn_output,
-            )
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, k_rope
+                )
+            num_tokens = q.shape[0]
+            kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+            if self.use_fia and (layer.tp_q_head_num // layer.tp_k_head_num) >= 8:
+                """layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
+                kv_c = kv_c.view(
+                    -1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
+                )
+                k_pe = k_pe.view(
+                    -1, self.page_size, layer.tp_k_head_num * self.qk_rope_head_dim
+                )
+                q = q.view(
+                    forward_batch.batch_size, -1, layer.tp_q_head_num, self.kv_lora_rank
+                )
+                q_rope = q_rope.view(
+                    forward_batch.batch_size,
+                    -1,
+                    layer.tp_q_head_num,
+                    self.qk_rope_head_dim,
+                )
+                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                    q,
+                    kv_c,
+                    kv_c,
+                    query_rope=q_rope,
+                    key_rope=k_pe,
+                    num_heads=layer.tp_q_head_num,
+                    num_key_value_heads=layer.tp_k_head_num,
+                    input_layout="BSND",
+                    atten_mask=None,
+                    sparse_mode=0,
+                    scale=layer.scaling,
+                    antiquant_mode=0,
+                    antiquant_scale=None,
+                    block_table=self.forward_metadata.block_tables,
+                    block_size=self.page_size,
+                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
+                )
+            else:
+                assert (
+                    self.graph_mode == False
+                )  # _npu_paged_attention_mla not support graph mode
+                q = torch.cat([q, q_rope], dim=-1)
+                query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                kv_c_and_k_pe_cache = torch.cat([kv_c, k_pe], dim=-1)
+                kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
+                    -1,
+                    self.page_size,
+                    layer.tp_k_head_num,
+                    self.kv_lora_rank + self.qk_rope_head_dim,
+                )
+                attn_output = torch.empty(
+                    [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
+                    dtype=q.dtype,
+                    device=q.device,
+                )
+                torch_npu._npu_paged_attention_mla(
+                    query=query,
+                    key_cache=kv_c_and_k_pe_cache,
+                    num_kv_heads=layer.tp_k_head_num,
+                    num_heads=layer.tp_q_head_num,
+                    scale_value=layer.scaling,
+                    block_table=self.forward_metadata.block_tables,
+                    context_lens=self.forward_metadata.seq_lens_cpu_int,
+                    mla_vheadsize=self.kv_lora_rank,
+                    out=attn_output,
+                )
         return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
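
Taken together, the ascend_backend.py hunks above add two user-visible knobs: an ASCEND_USE_FIA environment variable (read through get_bool_env_var) that routes prefill and decode onto torch.ops.npu.npu_fused_infer_attention_score, and a seq_lens_list_cumsum metadata field whose last entry is padded up to a multiple of the attention tensor-parallel size before being passed as actual_seq_lengths. The snippet below is a minimal standalone sketch of that behavior using plain os/NumPy; the helper names are illustrative and not sglang APIs.

# Standalone sketch (illustrative names, not sglang APIs) of the two knobs
# introduced by the Ascend backend changes above.
import os

import numpy as np


def fia_enabled() -> bool:
    # Same spirit as get_bool_env_var("ASCEND_USE_FIA", "False") in the diff.
    return os.environ.get("ASCEND_USE_FIA", "False").lower() in ("1", "true", "yes")


def seq_lens_list_cumsum(extend_seq_lens_cpu, tp_size):
    # Cumulative query lengths handed to the fused attention op as
    # actual_seq_lengths; the final entry is rounded up to a multiple of
    # tp_size, mirroring init_forward_metadata() in the diff.
    cumsum = np.cumsum(extend_seq_lens_cpu)
    cumsum[-1] = ((cumsum[-1] - 1) // tp_size + 1) * tp_size
    return cumsum


if __name__ == "__main__":
    # Three extend requests of lengths 5, 7, 3 with attention TP size 4:
    # the raw cumsum [5, 12, 15] becomes [5, 12, 16].
    print(fia_enabled(), seq_lens_list_cumsum([5, 7, 3], tp_size=4))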