sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +375 -51
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
@@ -25,7 +25,6 @@ import torch
  import torch.distributed

  from sglang.srt.eplb.expert_location import ExpertLocationMetadata
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import Withable, get_bool_env_var
@@ -288,14 +287,14 @@ class _SinglePassGatherer(ABC):
  )

  if server_args.expert_distribution_recorder_mode == "stat_approx":
- if server_args.moe_a2a_backend is not None and (
+ if server_args.moe_a2a_backend != "none" and (
  server_args.deepep_mode == "normal"
  ):
  return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
  else:
  raise NotImplementedError

- if server_args.moe_a2a_backend is not None:
+ if server_args.moe_a2a_backend != "none":
  if server_args.deepep_mode == "normal":
  return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
  elif server_args.deepep_mode == "low_latency":
@@ -215,6 +215,6 @@ class DeepSeekV3Detector(BaseFormatDetector):
  sequence_start_token=self.bot_token,
  sequence_end_token=self.eot_token,
  tool_call_separator="",
- call_rule_fmt='"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n" {arguments_rule} "\\n```<|tool▁call▁end|>"',
+ call_rule_fmt='"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n"{arguments_rule}"\\n```<|tool▁call▁end|>"',
  function_format="json",
  )
@@ -129,6 +129,25 @@ def get_config(
  config = AutoConfig.from_pretrained(
  model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
  )
+ if (
+ config.architectures is not None
+ and config.architectures[0] == "Phi4MMForCausalLM"
+ ):
+ # Phi4MMForCausalLM uses a hard-coded vision_config. See:
+ # https://github.com/vllm-project/vllm/blob/6071e989df1531b59ef35568f83f7351afb0b51e/vllm/model_executor/models/phi4mm.py#L71
+ # We set it here to support cases where num_attention_heads is not divisible by the TP size.
+ from transformers import SiglipVisionConfig
+
+ vision_config = {
+ "hidden_size": 1152,
+ "image_size": 448,
+ "intermediate_size": 4304,
+ "model_type": "siglip_vision_model",
+ "num_attention_heads": 16,
+ "num_hidden_layers": 26, # Model is originally 27-layer, we only need the first 26 layers for feature extraction.
+ "patch_size": 14,
+ }
+ config.vision_config = SiglipVisionConfig(**vision_config)
  text_config = get_hf_text_config(config=config)

  if isinstance(model, str) and text_config is not None:
@@ -244,6 +263,11 @@ def get_tokenizer(
  **kwargs,
  ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
  """Gets a tokenizer for the given model name via Huggingface."""
+ if tokenizer_name.endswith(".json"):
+ from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer
+
+ return TiktokenTokenizer(tokenizer_name)
+
  if tokenizer_mode == "slow":
  if kwargs.get("use_fast", False):
  raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
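With this change, a tokenizer path ending in ".json" bypasses the Hugging Face loader and is served by the new tiktoken-based tokenizer. A minimal usage sketch of the calling pattern; the file path and model name below are only illustrations, not taken from the diff:

    from sglang.srt.hf_transformers_utils import get_tokenizer

    # Routed to sglang/srt/tokenizer/tiktoken_tokenizer.py because of the ".json" suffix.
    tiktoken_tok = get_tokenizer("/path/to/tiktoken_tokenizer.json")

    # Any other name still goes through the Hugging Face AutoTokenizer path as before.
    hf_tok = get_tokenizer("meta-llama/Llama-3.1-8B-Instruct")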
@@ -0,0 +1,83 @@
+ import logging
+ import os
+ from dataclasses import dataclass
+ from multiprocessing import shared_memory
+ from pathlib import Path
+ from typing import List, Optional
+
+ import numpy as np
+ import torch
+
+ from sglang.srt.distributed.naive_distributed import get_naive_distributed
+ from sglang.srt.utils import check_cuda_result
+
+ logger = logging.getLogger(__name__)
+
+
+ class HostSharedMemoryManager:
+ def __init__(self, base_name: str):
+ self._base_name = Path(base_name)
+ self._operation_index = 0
+ self._records: List[_Record] = []
+
+ def malloc(self, *, shape, dtype):
+ meta_tensor = torch.empty(size=shape, dtype=dtype, device="meta")
+ raw = self._malloc_raw(num_bytes=meta_tensor.nbytes)
+ return raw.view(dtype).view(*shape)
+
+ def _malloc_raw(self, *, num_bytes: int) -> torch.Tensor:
+ import cuda.bindings.runtime as cuda_rt
+
+ self._operation_index += 1
+ shm_name = f"{self._base_name}_op{self._operation_index}"
+
+ # TODO handle dispose
+ if get_naive_distributed().get_rank() == 0:
+ shm = shared_memory.SharedMemory(name=shm_name, create=True, size=num_bytes)
+
+ get_naive_distributed().barrier()
+
+ if get_naive_distributed().get_rank() != 0:
+ shm = shared_memory.SharedMemory(name=shm_name)
+
+ np_array = np.ndarray((num_bytes,), dtype=np.uint8, buffer=shm.buf)
+ tensor = torch.from_numpy(np_array)
+
+ check_cuda_result(
+ cuda_rt.cudaHostRegister(
+ tensor.data_ptr(), num_bytes, cuda_rt.cudaHostRegisterPortable
+ )
+ )
+
+ get_naive_distributed().barrier()
+
+ self._records.append(
+ _Record(
+ shm=shm,
+ np_array=np_array,
+ tensor=tensor,
+ )
+ )
+ return tensor
+
+
+ @dataclass
+ class _Record:
+ shm: shared_memory.SharedMemory
+ np_array: np.ndarray
+ tensor: torch.Tensor
+
+
+ # Can have multi instances if needed
+ _instance: Optional[HostSharedMemoryManager] = None
+
+
+ def get_host_shared_memory_manager():
+ assert _instance is not None
+ return _instance
+
+
+ def set_host_shared_memory_manager(instance: HostSharedMemoryManager):
+ global _instance
+ assert _instance is None
+ _instance = instance
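The new module exposes a process-wide singleton. A rough usage sketch based only on the code above; the base name and tensor size are illustrative, and it assumes the naive_distributed group has already been initialized elsewhere:

    import torch

    from sglang.srt.host_shared_memory import (
        HostSharedMemoryManager,
        get_host_shared_memory_manager,
        set_host_shared_memory_manager,
    )

    # Install the singleton once per process (the setter asserts if called twice).
    set_host_shared_memory_manager(HostSharedMemoryManager(base_name="sglang_demo"))

    # Each malloc creates (rank 0) or attaches (other ranks) a SharedMemory segment,
    # wraps it in a torch tensor, and pins it with cudaHostRegister.
    buf = get_host_shared_memory_manager().malloc(shape=(1024, 1024), dtype=torch.bfloat16)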
@@ -1,7 +1,7 @@
  from __future__ import annotations

  from dataclasses import dataclass
- from typing import TYPE_CHECKING, Optional
+ from typing import TYPE_CHECKING, List, Optional

  import torch
  import torch_npu
@@ -27,6 +27,7 @@ class ForwardMetadata:
  # seq len inputs
  extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
  seq_lens_cpu_int: Optional[torch.Tensor] = None
+ seq_lens_cpu_list: Optional[List[int]] = None


  class AscendAttnBackend(AttentionBackend):
@@ -51,7 +52,7 @@ class AscendAttnBackend(AttentionBackend):

  def __init__(self, model_runner: ModelRunner):
  super().__init__()
- self.forward_metadata = ForwardMetadata()
+ self.forward_metadata = None
  self.device = model_runner.device
  self.gen_attention_mask(128, model_runner.dtype)
  self.page_size = model_runner.page_size
@@ -60,9 +61,15 @@ class AscendAttnBackend(AttentionBackend):
  self.kv_lora_rank = model_runner.model_config.kv_lora_rank
  self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
  self.native_attn = TorchNativeAttnBackend(model_runner)
+ self.graph_metadata = {}
+ self.max_context_len = model_runner.model_config.context_len
+ self.req_to_token = model_runner.req_to_token_pool.req_to_token
+ self.graph_mode = False

  def init_forward_metadata(self, forward_batch: ForwardBatch):
  """Init the metadata for a forward pass."""
+ self.forward_metadata = ForwardMetadata()
+
  self.forward_metadata.block_tables = (
  forward_batch.req_to_token_pool.req_to_token[
  forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
@@ -75,6 +82,63 @@ class AscendAttnBackend(AttentionBackend):
  )
  self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()

+ self.graph_mode = False
+
+ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+ self.graph_metadata = {
+ "block_tables": torch.empty(
+ (max_bs, self.max_context_len // self.page_size),
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ }
+
+ def init_forward_metadata_capture_cuda_graph(
+ self,
+ bs: int,
+ num_tokens: int,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ encoder_lens: Optional[torch.Tensor],
+ forward_mode: ForwardMode,
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+ ):
+ metadata = ForwardMetadata()
+
+ metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
+ metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
+
+ self.graph_metadata[bs] = metadata
+ self.forward_metadata = metadata
+
+ self.graph_mode = True
+
+ def init_forward_metadata_replay_cuda_graph(
+ self,
+ bs: int,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_lens_sum: int,
+ encoder_lens: Optional[torch.Tensor],
+ forward_mode: ForwardMode,
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+ seq_lens_cpu: Optional[torch.Tensor],
+ ):
+ metadata = self.graph_metadata[bs]
+ max_len = seq_lens_cpu[:bs].max().item()
+ max_seq_pages = (max_len + self.page_size - 1) // self.page_size
+
+ metadata.block_tables[:bs, :max_seq_pages].copy_(
+ self.req_to_token[req_pool_indices[:bs], :max_len][:, :: self.page_size]
+ // self.page_size
+ )
+ metadata.block_tables[:bs, max_seq_pages:].fill_(0)
+ metadata.block_tables[bs:, :].fill_(0)
+
+ self.forward_metadata = metadata
+
+ self.graph_mode = True
+
  def get_cuda_graph_seq_len_fill_value(self):
  return 1

@@ -167,28 +231,74 @@
  layer, forward_batch.out_cache_loc, k, v
  )
  if not self.use_mla:
- k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
- v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+ if self.graph_mode:
+ k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+ layer.layer_id
+ ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
+ v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
+ layer.layer_id
+ ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
+ query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
+ num_tokens = query.shape[0]
+ workspace = (
+ torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+ query,
+ k_cache,
+ v_cache,
+ block_table=self.forward_metadata.block_tables,
+ block_size=self.page_size,
+ num_heads=layer.tp_q_head_num,
+ num_key_value_heads=layer.tp_k_head_num,
+ input_layout="BSH",
+ scale=layer.scaling,
+ actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
+ )
+ )
+ output = torch.empty(
+ (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
+ dtype=q.dtype,
+ device=q.device,
+ )
+ softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+ torch_npu.npu_fused_infer_attention_score.out(
+ query,
+ k_cache,
+ v_cache,
+ block_table=self.forward_metadata.block_tables,
+ block_size=self.page_size,
+ num_heads=layer.tp_q_head_num,
+ num_key_value_heads=layer.tp_k_head_num,
+ input_layout="BSH",
+ scale=layer.scaling,
+ actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
+ workspace=workspace,
+ out=[output, softmax_lse],
+ )
+ else:
+ k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+ v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
+ layer.layer_id
+ )

- query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
- num_tokens = query.shape[0]
- output = torch.empty(
- (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
- dtype=query.dtype,
- device=query.device,
- )
+ query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+ num_tokens = query.shape[0]
+ output = torch.empty(
+ (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
+ dtype=query.dtype,
+ device=query.device,
+ )

- torch_npu._npu_paged_attention(
- query=query,
- key_cache=k_cache,
- value_cache=v_cache,
- num_heads=layer.tp_q_head_num,
- num_kv_heads=layer.tp_k_head_num,
- scale_value=layer.scaling,
- block_table=self.forward_metadata.block_tables,
- context_lens=self.forward_metadata.seq_lens_cpu_int,
- out=output,
- )
+ torch_npu._npu_paged_attention(
+ query=query,
+ key_cache=k_cache,
+ value_cache=v_cache,
+ num_heads=layer.tp_q_head_num,
+ num_kv_heads=layer.tp_k_head_num,
+ scale_value=layer.scaling,
+ block_table=self.forward_metadata.block_tables,
+ context_lens=self.forward_metadata.seq_lens_cpu_int,
+ out=output,
+ )
  return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
  else:
  query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
@@ -776,14 +776,13 @@ class FlashAttentionBackend(AttentionBackend):
  o = result
  else:
  if (
- not global_server_args_dict["disable_chunked_prefix_cache"]
- and forward_batch.attn_attend_prefix_cache is not None
+ forward_batch.attn_attend_prefix_cache is not None
  and not forward_batch.forward_mode.is_target_verify()
  and not forward_batch.forward_mode.is_draft_extend()
  ):
  # Do multi-head attention with chunked prefix cache
-
  if forward_batch.attn_attend_prefix_cache:
+ assert not global_server_args_dict["disable_chunked_prefix_cache"]
  # MHA for chunked prefix kv cache when running model with MLA
  assert forward_batch.prefix_chunk_idx is not None
  assert forward_batch.prefix_chunk_cu_seq_lens is not None
@@ -792,7 +791,8 @@ class FlashAttentionBackend(AttentionBackend):
  chunk_idx = forward_batch.prefix_chunk_idx
  assert chunk_idx >= 0

- output, lse, *rest = flash_attn_varlen_func(
+ assert forward_batch.mha_return_lse
+ output = flash_attn_varlen_func(
  q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
  k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
  v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
@@ -806,7 +806,7 @@ class FlashAttentionBackend(AttentionBackend):
  )
  else:
  # MHA for extend part of sequence without attending prefix kv cache
- output, lse, *rest = flash_attn_varlen_func(
+ output = flash_attn_varlen_func(
  q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
  k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
  v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
@@ -816,9 +816,13 @@ class FlashAttentionBackend(AttentionBackend):
  max_seqlen_k=metadata.max_seq_len_q,
  softmax_scale=layer.scaling,
  causal=True,
- return_softmax_lse=True,
+ return_softmax_lse=forward_batch.mha_return_lse,
  )
- return output, lse
+ if forward_batch.mha_return_lse:
+ output, lse, *rest = output
+ lse = torch.transpose(lse, 0, 1).contiguous()
+ return output, lse
+ return output
  else:
  # Do absorbed multi-latent attention
  kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
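The MHA extend path now returns the softmax log-sum-exp only when forward_batch.mha_return_lse is set; in that case flash_attn_varlen_func hands back an (output, lse, ...) tuple and the backend transposes lse before returning the pair. A minimal sketch of the consuming side; the call below is illustrative and its exact signature is not shown in this diff:

    # Sketch only: handling the conditional return introduced above.
    result = backend.forward_extend(q, k, v, layer, forward_batch)
    if forward_batch.mha_return_lse:
        # The backend already applied torch.transpose(lse, 0, 1).contiguous()
        # before returning (output, lse).
        output, lse = result
    else:
        output = result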
@@ -1163,6 +1167,8 @@ class FlashAttentionBackend(AttentionBackend):
  This creates fixed-size tensors that will be reused during CUDA graph replay
  to avoid memory allocations.
  """
+ max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
+
  # This is being used by normal decode and draft decode when topk == 1
  self.decode_cuda_graph_metadata = {
  "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
@@ -1174,13 +1180,7 @@ class FlashAttentionBackend(AttentionBackend):
  ),
  "page_table": torch.zeros(
  max_bs,
- (self.max_context_len + self.page_size - 1) // self.page_size,
- dtype=torch.int32,
- device=self.device,
- ),
- "page_table_draft_decode": torch.zeros(
- max_bs,
- (self.max_context_len + self.page_size - 1) // self.page_size,
+ max_num_pages,
  dtype=torch.int32,
  device=self.device,
  ),
@@ -1188,7 +1188,6 @@ class FlashAttentionBackend(AttentionBackend):
  0, self.max_context_len, self.page_size, device=self.device
  ),
  }
-
  # Only allocate local attention buffers if local attention is enabled
  # This prevents OOM errors when local attention is not being used
  if self.attention_chunk_size is not None:
@@ -1274,6 +1273,14 @@ class FlashAttentionBackend(AttentionBackend):
  self.speculative_num_draft_tokens is not None
  and self.speculative_num_draft_tokens > 0
  ):
+ # "page_table_draft_decode" will be set only when spec decoding enabled to save memory
+ self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
+ max_bs,
+ max_num_pages,
+ dtype=torch.int32,
+ device=self.device,
+ )
+
  self.target_verify_metadata = {
  "cache_seqlens": torch.zeros(
  max_bs, dtype=torch.int32, device=self.device
@@ -1290,7 +1297,7 @@
  ),
  "page_table": torch.zeros(
  max_bs,
- (self.max_context_len + self.page_size - 1) // self.page_size,
+ max_num_pages,
  dtype=torch.int32,
  device=self.device,
  ),
@@ -1313,7 +1320,7 @@
  ),
  "page_table": torch.zeros(
  max_bs,
- (self.max_context_len + self.page_size - 1) // self.page_size,
+ max_num_pages,
  dtype=torch.int32,
  device=self.device,
  ),
@@ -1263,11 +1263,12 @@ def should_use_tensor_core(
  # Calculate GQA group size
  gqa_group_size = num_attention_heads // num_kv_heads

- # Determine based on dtype and GQA group size
+ # For Flashinfer, a GQA group size of at least 4 is needed to efficiently
+ # use Tensor Cores, as it fuses the head group with the token dimension in MMA.
  if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
  return True
  elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
- return gqa_group_size > 4
+ return gqa_group_size >= 4
  else:
  return False
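Restated outside the diff, the updated heuristic amounts to the small predicate below. This is only a sketch of the branch shown above, not the full should_use_tensor_core function; the function name here is illustrative:

    import torch

    def should_use_tensor_core_sketch(kv_cache_dtype, num_attention_heads, num_kv_heads):
        # FP8 KV caches always take the tensor-core decode path.
        if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
            return True
        # For FP16/BF16, a GQA group of >= 4 query heads per KV head is now enough,
        # since FlashInfer fuses the head group with the token dimension in the MMA.
        gqa_group_size = num_attention_heads // num_kv_heads
        if kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
            return gqa_group_size >= 4
        return False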
 
@@ -1372,7 +1373,14 @@ def fast_decode_plan(

  if self.use_tensor_cores:
  # ALSO convert last_page_len to CPU
- last_page_len_host = last_page_len.cpu()
+ if page_size == 1:
+ # When page size is 1, last_page_len is always 1.
+ # Directly construct the host tensor rather than executing a device-to-host copy.
+ last_page_len_host = torch.ones(
+ (batch_size,), dtype=torch.int32, device="cpu"
+ )
+ else:
+ last_page_len_host = last_page_len.cpu()

  kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
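The optimization avoids a blocking device-to-host copy when every page holds a single token. The same idea shown in isolation as a rough sketch; the helper name below is illustrative and not part of the diff:

    import torch

    def last_page_len_to_host(last_page_len: torch.Tensor, page_size: int, batch_size: int):
        # With page_size == 1, every sequence's last page is trivially full,
        # so the host tensor can be built directly instead of syncing with
        # the GPU via .cpu().
        if page_size == 1:
            return torch.ones((batch_size,), dtype=torch.int32, device="cpu")
        return last_page_len.cpu()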