sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +375 -51
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/eplb/expert_distribution.py

@@ -25,7 +25,6 @@ import torch
 import torch.distributed
 
 from sglang.srt.eplb.expert_location import ExpertLocationMetadata
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import Withable, get_bool_env_var
@@ -288,14 +287,14 @@ class _SinglePassGatherer(ABC):
         )
 
         if server_args.expert_distribution_recorder_mode == "stat_approx":
-            if server_args.moe_a2a_backend
+            if server_args.moe_a2a_backend != "none" and (
                 server_args.deepep_mode == "normal"
             ):
                 return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
             else:
                 raise NotImplementedError
 
-        if server_args.moe_a2a_backend
+        if server_args.moe_a2a_backend != "none":
             if server_args.deepep_mode == "normal":
                 return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
             elif server_args.deepep_mode == "low_latency":
sglang/srt/function_call/deepseekv3_detector.py

@@ -215,6 +215,6 @@ class DeepSeekV3Detector(BaseFormatDetector):
             sequence_start_token=self.bot_token,
             sequence_end_token=self.eot_token,
             tool_call_separator="",
-            call_rule_fmt='"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n"
+            call_rule_fmt='"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n"{arguments_rule}"\\n```<|tool▁call▁end|>"',
             function_format="json",
         )
sglang/srt/hf_transformers_utils.py

@@ -129,6 +129,25 @@ def get_config(
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
     )
+    if (
+        config.architectures is not None
+        and config.architectures[0] == "Phi4MMForCausalLM"
+    ):
+        # Phi4MMForCausalLM uses a hard-coded vision_config. See:
+        # https://github.com/vllm-project/vllm/blob/6071e989df1531b59ef35568f83f7351afb0b51e/vllm/model_executor/models/phi4mm.py#L71
+        # We set it here to support cases where num_attention_heads is not divisible by the TP size.
+        from transformers import SiglipVisionConfig
+
+        vision_config = {
+            "hidden_size": 1152,
+            "image_size": 448,
+            "intermediate_size": 4304,
+            "model_type": "siglip_vision_model",
+            "num_attention_heads": 16,
+            "num_hidden_layers": 26,  # Model is originally 27-layer, we only need the first 26 layers for feature extraction.
+            "patch_size": 14,
+        }
+        config.vision_config = SiglipVisionConfig(**vision_config)
     text_config = get_hf_text_config(config=config)
 
     if isinstance(model, str) and text_config is not None:
@@ -244,6 +263,11 @@ def get_tokenizer(
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
+    if tokenizer_name.endswith(".json"):
+        from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer
+
+        return TiktokenTokenizer(tokenizer_name)
+
     if tokenizer_mode == "slow":
         if kwargs.get("use_fast", False):
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
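With the branch above, any tokenizer path ending in `.json` is now routed to the new Tiktoken-based tokenizer instead of the Hugging Face loader. A minimal sketch of the dispatch, assuming the artifacts are accessible; the model name and file path below are placeholders:

```python
from sglang.srt.hf_transformers_utils import get_tokenizer

# Unchanged behavior: Hugging Face model names still load a HF tokenizer.
hf_tokenizer = get_tokenizer("meta-llama/Llama-3.1-8B-Instruct")

# New in this release: a ".json" path short-circuits to TiktokenTokenizer.
tiktoken_tokenizer = get_tokenizer("/models/custom/tokenizer.json")
```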
sglang/srt/host_shared_memory.py (new file)

@@ -0,0 +1,83 @@
+import logging
+import os
+from dataclasses import dataclass
+from multiprocessing import shared_memory
+from pathlib import Path
+from typing import List, Optional
+
+import numpy as np
+import torch
+
+from sglang.srt.distributed.naive_distributed import get_naive_distributed
+from sglang.srt.utils import check_cuda_result
+
+logger = logging.getLogger(__name__)
+
+
+class HostSharedMemoryManager:
+    def __init__(self, base_name: str):
+        self._base_name = Path(base_name)
+        self._operation_index = 0
+        self._records: List[_Record] = []
+
+    def malloc(self, *, shape, dtype):
+        meta_tensor = torch.empty(size=shape, dtype=dtype, device="meta")
+        raw = self._malloc_raw(num_bytes=meta_tensor.nbytes)
+        return raw.view(dtype).view(*shape)
+
+    def _malloc_raw(self, *, num_bytes: int) -> torch.Tensor:
+        import cuda.bindings.runtime as cuda_rt
+
+        self._operation_index += 1
+        shm_name = f"{self._base_name}_op{self._operation_index}"
+
+        # TODO handle dispose
+        if get_naive_distributed().get_rank() == 0:
+            shm = shared_memory.SharedMemory(name=shm_name, create=True, size=num_bytes)
+
+        get_naive_distributed().barrier()
+
+        if get_naive_distributed().get_rank() != 0:
+            shm = shared_memory.SharedMemory(name=shm_name)
+
+        np_array = np.ndarray((num_bytes,), dtype=np.uint8, buffer=shm.buf)
+        tensor = torch.from_numpy(np_array)
+
+        check_cuda_result(
+            cuda_rt.cudaHostRegister(
+                tensor.data_ptr(), num_bytes, cuda_rt.cudaHostRegisterPortable
+            )
+        )
+
+        get_naive_distributed().barrier()
+
+        self._records.append(
+            _Record(
+                shm=shm,
+                np_array=np_array,
+                tensor=tensor,
+            )
+        )
+        return tensor
+
+
+@dataclass
+class _Record:
+    shm: shared_memory.SharedMemory
+    np_array: np.ndarray
+    tensor: torch.Tensor
+
+
+# Can have multi instances if needed
+_instance: Optional[HostSharedMemoryManager] = None
+
+
+def get_host_shared_memory_manager():
+    assert _instance is not None
+    return _instance
+
+
+def set_host_shared_memory_manager(instance: HostSharedMemoryManager):
+    global _instance
+    assert _instance is None
+    _instance = instance
sglang/srt/layers/attention/ascend_backend.py

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch_npu
@@ -27,6 +27,7 @@ class ForwardMetadata:
     # seq len inputs
     extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_int: Optional[torch.Tensor] = None
+    seq_lens_cpu_list: Optional[List[int]] = None
 
 
 class AscendAttnBackend(AttentionBackend):
@@ -51,7 +52,7 @@ class AscendAttnBackend(AttentionBackend):
 
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
-        self.forward_metadata =
+        self.forward_metadata = None
         self.device = model_runner.device
         self.gen_attention_mask(128, model_runner.dtype)
         self.page_size = model_runner.page_size
@@ -60,9 +61,15 @@ class AscendAttnBackend(AttentionBackend):
         self.kv_lora_rank = model_runner.model_config.kv_lora_rank
         self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
         self.native_attn = TorchNativeAttnBackend(model_runner)
+        self.graph_metadata = {}
+        self.max_context_len = model_runner.model_config.context_len
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.graph_mode = False
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
+        self.forward_metadata = ForwardMetadata()
+
         self.forward_metadata.block_tables = (
             forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
@@ -75,6 +82,63 @@ class AscendAttnBackend(AttentionBackend):
         )
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
 
+        self.graph_mode = False
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.graph_metadata = {
+            "block_tables": torch.empty(
+                (max_bs, self.max_context_len // self.page_size),
+                dtype=torch.int32,
+                device=self.device,
+            ),
+        }
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        metadata = ForwardMetadata()
+
+        metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
+        metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
+
+        self.graph_metadata[bs] = metadata
+        self.forward_metadata = metadata
+
+        self.graph_mode = True
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        metadata = self.graph_metadata[bs]
+        max_len = seq_lens_cpu[:bs].max().item()
+        max_seq_pages = (max_len + self.page_size - 1) // self.page_size
+
+        metadata.block_tables[:bs, :max_seq_pages].copy_(
+            self.req_to_token[req_pool_indices[:bs], :max_len][:, :: self.page_size]
+            // self.page_size
+        )
+        metadata.block_tables[:bs, max_seq_pages:].fill_(0)
+        metadata.block_tables[bs:, :].fill_(0)
+
+        self.forward_metadata = metadata
+
+        self.graph_mode = True
+
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
 
@@ -167,28 +231,74 @@ class AscendAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
         if not self.use_mla:
-
-
+            if self.graph_mode:
+                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                    layer.layer_id
+                ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
+                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
+                    layer.layer_id
+                ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
+                query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
+                num_tokens = query.shape[0]
+                workspace = (
+                    torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                        query,
+                        k_cache,
+                        v_cache,
+                        block_table=self.forward_metadata.block_tables,
+                        block_size=self.page_size,
+                        num_heads=layer.tp_q_head_num,
+                        num_key_value_heads=layer.tp_k_head_num,
+                        input_layout="BSH",
+                        scale=layer.scaling,
+                        actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
+                    )
+                )
+                output = torch.empty(
+                    (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
+                    dtype=q.dtype,
+                    device=q.device,
+                )
+                softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+                torch_npu.npu_fused_infer_attention_score.out(
+                    query,
+                    k_cache,
+                    v_cache,
+                    block_table=self.forward_metadata.block_tables,
+                    block_size=self.page_size,
+                    num_heads=layer.tp_q_head_num,
+                    num_key_value_heads=layer.tp_k_head_num,
+                    input_layout="BSH",
+                    scale=layer.scaling,
+                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
+                    workspace=workspace,
+                    out=[output, softmax_lse],
+                )
+            else:
+                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
+                    layer.layer_id
+                )
 
-
-
-
-
-
-
+                query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                num_tokens = query.shape[0]
+                output = torch.empty(
+                    (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
+                    dtype=query.dtype,
+                    device=query.device,
+                )
 
-
-
-
-
-
-
-
-
-
-
-
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=k_cache,
+                    value_cache=v_cache,
+                    num_heads=layer.tp_q_head_num,
+                    num_kv_heads=layer.tp_k_head_num,
+                    scale_value=layer.scaling,
+                    block_table=self.forward_metadata.block_tables,
+                    context_lens=self.forward_metadata.seq_lens_cpu_int,
+                    out=output,
+                )
         return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
     else:
         query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
sglang/srt/layers/attention/flashattention_backend.py

@@ -776,14 +776,13 @@ class FlashAttentionBackend(AttentionBackend):
                 o = result
             else:
                 if (
-                    not
-                    and forward_batch.attn_attend_prefix_cache is not None
+                    forward_batch.attn_attend_prefix_cache is not None
                     and not forward_batch.forward_mode.is_target_verify()
                     and not forward_batch.forward_mode.is_draft_extend()
                 ):
                     # Do multi-head attention with chunked prefix cache
-
                     if forward_batch.attn_attend_prefix_cache:
+                        assert not global_server_args_dict["disable_chunked_prefix_cache"]
                         # MHA for chunked prefix kv cache when running model with MLA
                         assert forward_batch.prefix_chunk_idx is not None
                         assert forward_batch.prefix_chunk_cu_seq_lens is not None
@@ -792,7 +791,8 @@ class FlashAttentionBackend(AttentionBackend):
                         chunk_idx = forward_batch.prefix_chunk_idx
                         assert chunk_idx >= 0
 
-
+                        assert forward_batch.mha_return_lse
+                        output = flash_attn_varlen_func(
                             q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                             k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                             v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
@@ -806,7 +806,7 @@ class FlashAttentionBackend(AttentionBackend):
                         )
                     else:
                         # MHA for extend part of sequence without attending prefix kv cache
-                        output
+                        output = flash_attn_varlen_func(
                             q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                             k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                             v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
@@ -816,9 +816,13 @@ class FlashAttentionBackend(AttentionBackend):
                             max_seqlen_k=metadata.max_seq_len_q,
                             softmax_scale=layer.scaling,
                             causal=True,
-                            return_softmax_lse=
+                            return_softmax_lse=forward_batch.mha_return_lse,
                         )
-
+                    if forward_batch.mha_return_lse:
+                        output, lse, *rest = output
+                        lse = torch.transpose(lse, 0, 1).contiguous()
+                        return output, lse
+                    return output
                 else:
                     # Do absorbed multi-latent attention
                     kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
@@ -1163,6 +1167,8 @@ class FlashAttentionBackend(AttentionBackend):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
+        max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
+
         # This is being used by normal decode and draft decode when topk == 1
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
@@ -1174,13 +1180,7 @@ class FlashAttentionBackend(AttentionBackend):
             ),
             "page_table": torch.zeros(
                 max_bs,
-
-                dtype=torch.int32,
-                device=self.device,
-            ),
-            "page_table_draft_decode": torch.zeros(
-                max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
+                max_num_pages,
                 dtype=torch.int32,
                 device=self.device,
             ),
@@ -1188,7 +1188,6 @@ class FlashAttentionBackend(AttentionBackend):
                 0, self.max_context_len, self.page_size, device=self.device
             ),
         }
-
         # Only allocate local attention buffers if local attention is enabled
         # This prevents OOM errors when local attention is not being used
         if self.attention_chunk_size is not None:
@@ -1274,6 +1273,14 @@ class FlashAttentionBackend(AttentionBackend):
             self.speculative_num_draft_tokens is not None
             and self.speculative_num_draft_tokens > 0
         ):
+            # "page_table_draft_decode" will be set only when spec decoding enabled to save memory
+            self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
+                max_bs,
+                max_num_pages,
+                dtype=torch.int32,
+                device=self.device,
+            )
+
             self.target_verify_metadata = {
                 "cache_seqlens": torch.zeros(
                     max_bs, dtype=torch.int32, device=self.device
@@ -1290,7 +1297,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ),
                 "page_table": torch.zeros(
                     max_bs,
-
+                    max_num_pages,
                     dtype=torch.int32,
                     device=self.device,
                 ),
@@ -1313,7 +1320,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ),
                 "page_table": torch.zeros(
                     max_bs,
-
+                    max_num_pages,
                     dtype=torch.int32,
                     device=self.device,
                 ),
sglang/srt/layers/attention/flashinfer_backend.py

@@ -1263,11 +1263,12 @@ def should_use_tensor_core(
     # Calculate GQA group size
     gqa_group_size = num_attention_heads // num_kv_heads
 
-    #
+    # For Flashinfer, a GQA group size of at least 4 is needed to efficiently
+    # use Tensor Cores, as it fuses the head group with the token dimension in MMA.
     if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         return True
     elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
-        return gqa_group_size
+        return gqa_group_size >= 4
     else:
         return False
 
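The tightened heuristic above only routes fp16/bf16 KV-cache decoding to FlashInfer's tensor-core path when the GQA group is large enough. A worked example with assumed head counts (not taken from this diff):

```python
# 32 query heads shared across 8 KV heads gives a GQA group size of 4, which now
# opts into the tensor-core decode kernels; a plain MHA model (group size 1)
# keeps the non-tensor-core path for fp16/bf16 caches.
num_attention_heads, num_kv_heads = 32, 8
gqa_group_size = num_attention_heads // num_kv_heads  # -> 4
use_tensor_cores = gqa_group_size >= 4                 # -> True
```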
@@ -1372,7 +1373,14 @@ def fast_decode_plan(
 
     if self.use_tensor_cores:
         # ALSO convert last_page_len to CPU
-
+        if page_size == 1:
+            # When page size is 1, last_page_len is always 1.
+            # Directly construct the host tensor rather than executing a device-to-host copy.
+            last_page_len_host = torch.ones(
+                (batch_size,), dtype=torch.int32, device="cpu"
+            )
+        else:
+            last_page_len_host = last_page_len.cpu()
 
         kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
 