sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +79 -53
  3. sglang/bench_serving.py +186 -14
  4. sglang/profiler.py +0 -1
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/longcat_flash.py +104 -0
  7. sglang/srt/configs/model_config.py +12 -0
  8. sglang/srt/connector/__init__.py +1 -1
  9. sglang/srt/connector/base_connector.py +1 -2
  10. sglang/srt/connector/redis.py +2 -2
  11. sglang/srt/connector/serde/__init__.py +1 -1
  12. sglang/srt/connector/serde/safe_serde.py +4 -3
  13. sglang/srt/conversation.py +38 -5
  14. sglang/srt/disaggregation/ascend/conn.py +75 -0
  15. sglang/srt/disaggregation/launch_lb.py +0 -13
  16. sglang/srt/disaggregation/mini_lb.py +33 -8
  17. sglang/srt/disaggregation/prefill.py +1 -1
  18. sglang/srt/distributed/parallel_state.py +24 -14
  19. sglang/srt/entrypoints/engine.py +19 -12
  20. sglang/srt/entrypoints/http_server.py +174 -34
  21. sglang/srt/entrypoints/openai/protocol.py +87 -24
  22. sglang/srt/entrypoints/openai/serving_chat.py +50 -9
  23. sglang/srt/entrypoints/openai/serving_completions.py +15 -0
  24. sglang/srt/eplb/eplb_manager.py +26 -2
  25. sglang/srt/eplb/expert_distribution.py +29 -2
  26. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  27. sglang/srt/function_call/function_call_parser.py +2 -0
  28. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  29. sglang/srt/harmony_parser.py +588 -0
  30. sglang/srt/hf_transformers_utils.py +26 -7
  31. sglang/srt/layers/activation.py +12 -0
  32. sglang/srt/layers/attention/ascend_backend.py +374 -136
  33. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  34. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  35. sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
  36. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  38. sglang/srt/layers/communicator.py +1 -2
  39. sglang/srt/layers/layernorm.py +28 -3
  40. sglang/srt/layers/linear.py +3 -2
  41. sglang/srt/layers/logits_processor.py +1 -1
  42. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  43. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  44. sglang/srt/layers/moe/ep_moe/layer.py +13 -13
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/topk.py +35 -12
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
  49. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  50. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
  51. sglang/srt/layers/quantization/fp8.py +2 -1
  52. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  53. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  54. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  55. sglang/srt/layers/quantization/mxfp4.py +25 -27
  56. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  57. sglang/srt/layers/quantization/utils.py +13 -0
  58. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  59. sglang/srt/layers/rotary_embedding.py +28 -1
  60. sglang/srt/layers/sampler.py +29 -5
  61. sglang/srt/layers/utils.py +0 -14
  62. sglang/srt/managers/cache_controller.py +237 -204
  63. sglang/srt/managers/detokenizer_manager.py +48 -2
  64. sglang/srt/managers/io_struct.py +57 -0
  65. sglang/srt/managers/mm_utils.py +5 -1
  66. sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
  67. sglang/srt/managers/scheduler.py +94 -9
  68. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  69. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  70. sglang/srt/managers/tokenizer_manager.py +122 -42
  71. sglang/srt/mem_cache/chunk_cache.py +1 -1
  72. sglang/srt/mem_cache/hicache_storage.py +51 -23
  73. sglang/srt/mem_cache/hiradix_cache.py +87 -71
  74. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  75. sglang/srt/mem_cache/memory_pool.py +77 -14
  76. sglang/srt/mem_cache/memory_pool_host.py +4 -5
  77. sglang/srt/mem_cache/radix_cache.py +6 -4
  78. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  79. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
  80. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
  81. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  82. sglang/srt/model_executor/model_runner.py +6 -5
  83. sglang/srt/model_loader/loader.py +15 -24
  84. sglang/srt/model_loader/utils.py +12 -0
  85. sglang/srt/models/deepseek_v2.py +38 -13
  86. sglang/srt/models/gpt_oss.py +2 -15
  87. sglang/srt/models/llama_eagle3.py +4 -0
  88. sglang/srt/models/longcat_flash.py +1015 -0
  89. sglang/srt/models/longcat_flash_nextn.py +691 -0
  90. sglang/srt/models/qwen2.py +26 -3
  91. sglang/srt/models/qwen2_5_vl.py +66 -41
  92. sglang/srt/models/qwen2_moe.py +22 -2
  93. sglang/srt/models/transformers.py +1 -1
  94. sglang/srt/multimodal/processors/base_processor.py +4 -2
  95. sglang/srt/reasoning_parser.py +56 -300
  96. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  97. sglang/srt/server_args.py +122 -56
  98. sglang/srt/speculative/eagle_worker.py +28 -8
  99. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  100. sglang/srt/utils.py +73 -5
  101. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  102. sglang/version.py +1 -1
  103. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
  104. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
  105. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
  106. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/ascend_backend.py
@@ -12,11 +12,16 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import get_bool_env_var

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

+import os
+
+import numpy as np
+

 @dataclass
 class ForwardMetadata:
@@ -54,17 +59,27 @@ class AscendAttnBackend(AttentionBackend):
         super().__init__()
         self.forward_metadata = None
         self.device = model_runner.device
-        self.gen_attention_mask(128, model_runner.dtype)
         self.page_size = model_runner.page_size
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
-            self.native_attn = TorchNativeAttnBackend(model_runner)
+        self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.graph_mode = False
+        self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False")
+        if not self.use_fia:
+            self.gen_attention_mask(128, model_runner.dtype)
+        mask_length = 2048
+        self.fia_mask = ~torch.tril(
+            torch.ones(
+                (mask_length, mask_length),
+                dtype=torch.bool,
+                device=model_runner.device,
+            )
+        )

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
@@ -81,6 +96,9 @@ class AscendAttnBackend(AttentionBackend):
                 forward_batch.extend_seq_lens.cpu().int()
             )
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
+        self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
+            forward_batch.extend_seq_lens_cpu
+        )

         self.graph_mode = False

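Note (illustrative, not part of the diff): seq_lens_list_cumsum turns per-request extend lengths into cumulative end offsets; the MLA extend path below passes it as actual_seq_lengths / actual_seq_lengths_kv for the TND layout. A minimal sketch with made-up lengths:

import numpy as np

extend_seq_lens_cpu = [5, 3, 7]  # hypothetical per-request new-token counts
seq_lens_list_cumsum = np.cumsum(extend_seq_lens_cpu)
print(seq_lens_list_cumsum)  # [ 5  8 15]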
@@ -140,7 +158,7 @@ class AscendAttnBackend(AttentionBackend):
         self.graph_mode = True

     def get_cuda_graph_seq_len_fill_value(self):
-        return 1
+        return 0

     def forward_extend(
         self,
@@ -149,73 +167,256 @@ class AscendAttnBackend(AttentionBackend):
         v,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache=True,
+        save_kv_cache: bool = True,
     ):
-        if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
+        if not self.use_mla:
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )
+
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+            if self.use_fia:
+                """FIA will support multi-bs in the later version of CANN"""
+                q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                attn_output = torch.empty(
+                    (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
+                    device=q.device,
+                    dtype=q.dtype,
+                )
+                q_len_offset = 0
+                for q_len in forward_batch.extend_seq_lens_cpu:
+                    attn_output[q_len_offset : q_len_offset + q_len] = (
+                        torch.ops.npu.npu_fused_infer_attention_score(
+                            q[None, q_len_offset : q_len_offset + q_len],
+                            k[None, q_len_offset : q_len_offset + q_len],
+                            v[None, q_len_offset : q_len_offset + q_len],
+                            num_heads=layer.tp_q_head_num,
+                            num_key_value_heads=layer.tp_k_head_num,
+                            input_layout="BSND",  # todo, TND not supports q_heads!=k_heads
+                            atten_mask=self.fia_mask.unsqueeze(0),
+                            sparse_mode=3,
+                            scale=layer.scaling,
+                            next_tokens=0,
+                        )[0]
+                    )
+                    q_len_offset += q_len
+                attn_output = attn_output.view(
+                    -1, layer.tp_q_head_num * layer.v_head_dim
+                )
+
+            else:
+                if layer.qk_head_dim <= 128:
+                    query = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+                    attn_output = torch.empty(
+                        (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                        dtype=query.dtype,
+                        device=query.device,
+                    )
+
+                    torch_npu._npu_flash_attention_qlens(
+                        query=query,
+                        key_cache=k_cache,
+                        value_cache=v_cache,
+                        mask=self.mask,
+                        block_table=self.forward_metadata.block_tables,
+                        seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
+                        context_lens=self.forward_metadata.seq_lens_cpu_int,
+                        scale_value=layer.scaling,
+                        num_heads=layer.tp_q_head_num,
+                        num_kv_heads=layer.tp_k_head_num,
+                        out=attn_output,
+                    )
+                else:
+                    if layer.qk_head_dim != layer.v_head_dim:
+                        attn_output = q.new_empty(
+                            (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                        )
+                    else:
+                        attn_output = torch.empty_like(q)
+
+                    use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+                    q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                    o_ = attn_output.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+                    causal = True
+                    if (
+                        layer.is_cross_attention
+                        or layer.attn_type == AttentionType.ENCODER_ONLY
+                    ):
+                        causal = False
+
+                    self.native_attn._run_sdpa_forward_extend(
+                        q_,
+                        o_,
+                        k_cache.view(-1, layer.tp_k_head_num, layer.qk_head_dim),
+                        v_cache.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                        forward_batch.req_to_token_pool.req_to_token,
+                        forward_batch.req_pool_indices,
+                        forward_batch.seq_lens,
+                        forward_batch.extend_prefix_lens,
+                        forward_batch.extend_seq_lens,
+                        scaling=layer.scaling,
+                        enable_gqa=use_gqa,
+                        causal=causal,
+                    )
+        else:
+            assert (
+                layer.qk_head_dim != layer.v_head_dim
+            ), "FIA only supports qk_head_dim != v_head_dim"
+            q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
+            k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
+
+            attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                q_nope,
+                k_nope,
+                v,
+                query_rope=q_rope,
+                key_rope=k_rope,
+                num_heads=layer.tp_q_head_num,
+                input_layout="TND",
+                atten_mask=self.fia_mask,
+                sparse_mode=3,
+                actual_seq_lengths=self.forward_metadata.seq_lens_list_cumsum,
+                actual_seq_lengths_kv=self.forward_metadata.seq_lens_list_cumsum,
+                scale=layer.scaling,
+                next_tokens=0,
             )

-        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-        v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+        return attn_output
+
+    def forward_decode_graph(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ):
+        if save_kv_cache:
+            if self.use_mla:
+                k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+                k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, k_rope
+                )
+            else:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )

         if not self.use_mla:
-            query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
+                layer.layer_id
+            ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
+            query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
+            if self.forward_metadata.seq_lens_cpu_int is None:
+                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
+            else:
+                actual_seq_len_kv = (
+                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                )
+            num_tokens = query.shape[0]
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                query,
+                k_cache,
+                v_cache,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                input_layout="BSH",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+            )
             output = torch.empty(
-                (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
-                dtype=query.dtype,
-                device=query.device,
+                (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
+                dtype=q.dtype,
+                device=q.device,
             )
-
-            torch_npu._npu_flash_attention_qlens(
-                query=query,
-                key_cache=k_cache,
-                value_cache=v_cache,
-                mask=self.mask,
+            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+            torch_npu.npu_fused_infer_attention_score.out(
+                query,
+                k_cache,
+                v_cache,
                 block_table=self.forward_metadata.block_tables,
-                seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
-                context_lens=self.forward_metadata.seq_lens_cpu_int,
-                scale_value=layer.scaling,
+                block_size=self.page_size,
                 num_heads=layer.tp_q_head_num,
-                num_kv_heads=layer.tp_k_head_num,
-                out=output,
+                num_key_value_heads=layer.tp_k_head_num,
+                input_layout="BSH",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                workspace=workspace,
+                out=[output, softmax_lse],
             )
-            return output
+            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
         else:
-            if layer.qk_head_dim != layer.v_head_dim:
-                o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+            c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            k_rope_cache = k_rope.view(
+                -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
+            )
+            c_kv_cache = c_kv.view(
+                -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
+            )
+
+            q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank)
+            q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
+            if self.forward_metadata.seq_lens_cpu_int is None:
+                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
             else:
-                o = torch.empty_like(q)
-
-            use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
-
-            q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
-            o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
-
-            causal = True
-            if (
-                layer.is_cross_attention
-                or layer.attn_type == AttentionType.ENCODER_ONLY
-            ):
-                causal = False
-
-            self.native_attn._run_sdpa_forward_extend(
-                q_,
-                o_,
-                k_cache.view(
-                    -1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
-                ),
-                v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
-                forward_batch.req_to_token_pool.req_to_token,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.extend_prefix_lens,
-                forward_batch.extend_seq_lens,
-                scaling=layer.scaling,
-                enable_gqa=use_gqa,
-                causal=causal,
+                actual_seq_len_kv = (
+                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                )
+
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                q_nope,
+                c_kv_cache,
+                c_kv_cache,
+                query_rope=q_rope,
+                key_rope=k_rope_cache,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                input_layout="BNSD",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                sparse_mode=0,
+            )
+            output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
+            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                c_kv_cache,
+                c_kv_cache,
+                query_rope=q_rope,
+                key_rope=k_rope_cache,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                input_layout="BNSD",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                sparse_mode=0,
+                workspace=workspace,
+                out=[output, softmax_lse],
             )
-            return o
+            return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)

     def forward_decode(
         self,
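Note (illustrative, not part of the diff): when ASCEND_USE_FIA is enabled, the non-MLA extend path above runs the fused kernel one request at a time, slicing the packed token dimension by the per-request extend lengths (the docstring notes that multi-batch FIA is deferred to a later CANN release). A standalone sketch of the slicing pattern, with hypothetical sizes and without the NPU kernel call:

import torch

extend_seq_lens_cpu = [4, 2, 3]  # hypothetical per-request new-token counts
num_heads, head_dim = 8, 64
q = torch.randn(sum(extend_seq_lens_cpu), num_heads, head_dim)  # packed (tokens, heads, dim)

q_len_offset = 0
for q_len in extend_seq_lens_cpu:
    q_req = q[None, q_len_offset : q_len_offset + q_len]  # (1, q_len, heads, dim): a batch of one
    # the real code calls torch.ops.npu.npu_fused_infer_attention_score on q_req/k_req/v_req here
    q_len_offset += q_len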
@@ -224,65 +425,58 @@ class AscendAttnBackend(AttentionBackend):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache=True,
+        save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
-        if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
+        if self.graph_mode:
+            return self.forward_decode_graph(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope=q_rope,
+                k_rope=k_rope,
             )
+
         if not self.use_mla:
-            if self.graph_mode:
-                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
-                    layer.layer_id
-                ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
-                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
-                    layer.layer_id
-                ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
-                query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
-                num_tokens = query.shape[0]
-                workspace = (
-                    torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                        query,
-                        k_cache,
-                        v_cache,
-                        block_table=self.forward_metadata.block_tables,
-                        block_size=self.page_size,
-                        num_heads=layer.tp_q_head_num,
-                        num_key_value_heads=layer.tp_k_head_num,
-                        input_layout="BSH",
-                        scale=layer.scaling,
-                        actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
-                    )
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
                 )
-                output = torch.empty(
-                    (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
-                    dtype=q.dtype,
-                    device=q.device,
-                )
-                softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
-                torch_npu.npu_fused_infer_attention_score.out(
-                    query,
-                    k_cache,
-                    v_cache,
-                    block_table=self.forward_metadata.block_tables,
-                    block_size=self.page_size,
+            num_tokens = q.shape[0]
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+            if self.use_fia:
+                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                    q.view(
+                        forward_batch.batch_size,
+                        -1,
+                        layer.tp_q_head_num,
+                        layer.qk_head_dim,
+                    ),
+                    k_cache.view(
+                        -1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
+                    ),
+                    v_cache.view(
+                        -1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
+                    ),
                     num_heads=layer.tp_q_head_num,
                     num_key_value_heads=layer.tp_k_head_num,
-                    input_layout="BSH",
+                    input_layout="BSND",
+                    atten_mask=None,
+                    block_size=self.page_size,
+                    block_table=self.forward_metadata.block_tables,
+                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
                     scale=layer.scaling,
-                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
-                    workspace=workspace,
-                    out=[output, softmax_lse],
                 )
             else:
-                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
-                    layer.layer_id
-                )
-
-                query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
                 num_tokens = query.shape[0]
-                output = torch.empty(
+                attn_output = torch.empty(
                     (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
                     dtype=query.dtype,
                     device=query.device,
@@ -297,36 +491,80 @@ class AscendAttnBackend(AttentionBackend):
                 scale_value=layer.scaling,
                 block_table=self.forward_metadata.block_tables,
                 context_lens=self.forward_metadata.seq_lens_cpu_int,
-                out=output,
+                out=attn_output,
             )
-            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+            return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
         else:
-            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-            num_tokens = query.shape[0]
-            kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
-                layer.layer_id
-            )
-            kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
-                -1,
-                self.page_size,
-                layer.tp_k_head_num,
-                self.kv_lora_rank + self.qk_rope_head_dim,
-            )
-
-            attn_output = torch.empty(
-                [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
-                dtype=q.dtype,
-                device=q.device,
-            )
-            torch_npu._npu_paged_attention_mla(
-                query=query,
-                key_cache=kv_c_and_k_pe_cache,
-                num_kv_heads=layer.tp_k_head_num,
-                num_heads=layer.tp_q_head_num,
-                scale_value=layer.scaling,
-                block_table=self.forward_metadata.block_tables,
-                context_lens=self.forward_metadata.seq_lens_cpu_int,
-                mla_vheadsize=self.kv_lora_rank,
-                out=attn_output,
-            )
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, k_rope
+                )
+            num_tokens = q.shape[0]
+            kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+            if self.use_fia and (layer.tp_q_head_num // layer.tp_k_head_num) >= 8:
+                """layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
+                kv_c = kv_c.view(
+                    -1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
+                )
+                k_pe = k_pe.view(
+                    -1, self.page_size, layer.tp_k_head_num * self.qk_rope_head_dim
+                )
+                q = q.view(
+                    forward_batch.batch_size, -1, layer.tp_q_head_num, self.kv_lora_rank
+                )
+                q_rope = q_rope.view(
+                    forward_batch.batch_size,
+                    -1,
+                    layer.tp_q_head_num,
+                    self.qk_rope_head_dim,
+                )
+                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                    q,
+                    kv_c,
+                    kv_c,
+                    query_rope=q_rope,
+                    key_rope=k_pe,
+                    num_heads=layer.tp_q_head_num,
+                    num_key_value_heads=layer.tp_k_head_num,
+                    input_layout="BSND",
+                    atten_mask=None,
+                    sparse_mode=0,
+                    scale=layer.scaling,
+                    antiquant_mode=0,
+                    antiquant_scale=None,
+                    block_table=self.forward_metadata.block_tables,
+                    block_size=self.page_size,
+                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
+                )
+            else:
+                assert (
+                    self.graph_mode == False
+                )  # _npu_paged_attention_mla not support graph mode
+                q = torch.cat([q, q_rope], dim=-1)
+                query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                kv_c_and_k_pe_cache = torch.cat([kv_c, k_pe], dim=-1)
+                kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
+                    -1,
+                    self.page_size,
+                    layer.tp_k_head_num,
+                    self.kv_lora_rank + self.qk_rope_head_dim,
+                )
+                attn_output = torch.empty(
+                    [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
+                    dtype=q.dtype,
+                    device=q.device,
+                )
+                torch_npu._npu_paged_attention_mla(
+                    query=query,
+                    key_cache=kv_c_and_k_pe_cache,
+                    num_kv_heads=layer.tp_k_head_num,
+                    num_heads=layer.tp_q_head_num,
+                    scale_value=layer.scaling,
+                    block_table=self.forward_metadata.block_tables,
+                    context_lens=self.forward_metadata.seq_lens_cpu_int,
+                    mla_vheadsize=self.kv_lora_rank,
+                    out=attn_output,
+                )
             return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
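Note (illustrative, not part of the diff): the new FIA code paths in this backend are gated by the ASCEND_USE_FIA environment variable, read through sglang.srt.utils.get_bool_env_var at backend construction. A minimal stand-in for how such a boolean flag is typically parsed; the helper name read_bool_env_var and its accepted values are assumptions, not sglang's implementation:

import os

def read_bool_env_var(name: str, default: str = "false") -> bool:
    # Hypothetical stand-in; sglang ships its own get_bool_env_var helper.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

use_fia = read_bool_env_var("ASCEND_USE_FIA", "False")
print(use_fia)  # False unless ASCEND_USE_FIA is set to a truthy value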