sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -33,16 +33,18 @@ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 from sglang.srt.configs import (
     ChatGLMConfig,
     DbrxConfig,
+    DeepseekVL2Config,
     ExaoneConfig,
     MultiModalityConfig,
-    Qwen2_5_VLConfig,
 )
+from sglang.srt.connector import create_remote_connector
+from sglang.srt.utils import is_remote_url

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
     DbrxConfig.model_type: DbrxConfig,
     ExaoneConfig.model_type: ExaoneConfig,
-
+    DeepseekVL2Config.model_type: DeepseekVL2Config,
     MultiModalityConfig.model_type: MultiModalityConfig,
 }

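The registry addition above lets sglang resolve DeepseekVL2 models through its own config class, keyed by model_type. A minimal standalone sketch of that dispatch pattern (toy class and model_type string, not the package's real ones):

    from typing import Dict, Type

    class ToyDeepseekVL2Config:  # stand-in for sglang's DeepseekVL2Config
        model_type = "toy_deepseek_vl2"

    _CONFIG_REGISTRY: Dict[str, Type] = {
        ToyDeepseekVL2Config.model_type: ToyDeepseekVL2Config,
    }

    def resolve_config_cls(model_type: str):
        # Look up a registered config class by model_type; unknown types return None here.
        return _CONFIG_REGISTRY.get(model_type)

    assert resolve_config_cls("toy_deepseek_vl2") is ToyDeepseekVL2Config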
@@ -155,6 +157,14 @@ def get_tokenizer(
         kwargs["gguf_file"] = tokenizer_name
         tokenizer_name = Path(tokenizer_name).parent

+    if is_remote_url(tokenizer_name):
+        # BaseConnector implements __del__() to clean up the local dir.
+        # Since config files need to exist all the time, so we DO NOT use
+        # with statement to avoid closing the client.
+        client = create_remote_connector(tokenizer_name)
+        client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
+        tokenizer_name = client.get_local_dir()
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
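With the block above, get_tokenizer() can pull tokenizer files from a remote store before handing the local copy to AutoTokenizer. A hedged usage sketch (the s3:// URL is hypothetical and assumes the matching connector credentials are configured):

    from sglang.srt.hf_transformers_utils import get_tokenizer

    # is_remote_url() routes this path through create_remote_connector(), which
    # downloads everything except weight files (*.pt, *.safetensors, *.bin)
    # into a local directory before AutoTokenizer loads it.
    tokenizer = get_tokenizer("s3://example-bucket/example-model")
    print(type(tokenizer).__name__)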
@@ -207,11 +217,26 @@ def get_processor(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ):
+    # pop 'revision' from kwargs if present.
+    revision = kwargs.pop("revision", tokenizer_revision)
+
+    config = AutoConfig.from_pretrained(
+        tokenizer_name,
+        trust_remote_code=trust_remote_code,
+        revision=revision,
+        **kwargs,
+    )
+
+    # fix: for Qwen2-VL model, inject default 'size' if not provided.
+    if config.model_type in {"qwen2_vl"}:
+        if "size" not in kwargs:
+            kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
+
     processor = AutoProcessor.from_pretrained(
         tokenizer_name,
         *args,
         trust_remote_code=trust_remote_code,
-
+        revision=revision,
         **kwargs,
     )

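A standalone sketch mirroring the kwargs handling the new get_processor() code performs (the helper name is hypothetical; it only restates the logic shown above):

    def resolve_processor_kwargs(model_type, tokenizer_revision=None, **kwargs):
        # Prefer an explicit 'revision' kwarg, otherwise fall back to tokenizer_revision.
        revision = kwargs.pop("revision", tokenizer_revision)
        # Inject the default Qwen2-VL image-size bounds when the caller gave none.
        if model_type in {"qwen2_vl"} and "size" not in kwargs:
            kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
        return revision, kwargs

    revision, kwargs = resolve_processor_kwargs("qwen2_vl", tokenizer_revision="main")
    assert revision == "main" and kwargs["size"]["shortest_edge"] == 3136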
sglang/srt/layers/activation.py
CHANGED
@@ -23,7 +23,9 @@ import torch.nn.functional as F

 from sglang.srt.utils import is_cuda_available

-
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

 from sglang.srt.custom_op import CustomOp
@@ -165,7 +167,7 @@ def get_act_fn(
     return act_fn


-if not
+if not _is_cuda:
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
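The module-level _is_cuda flag above gates both the sgl_kernel import and the fallback log message. A small standalone sketch of the same guard pattern (the pure-PyTorch fallback is illustrative only, not sglang's actual fallback path):

    import importlib.util

    import torch
    import torch.nn.functional as F

    # Stand-in for is_cuda_available(): only import sgl_kernel when it is present.
    _is_cuda = importlib.util.find_spec("sgl_kernel") is not None

    if _is_cuda:
        from sgl_kernel import silu_and_mul
    else:
        def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
            # Reference fallback: SiLU-gate the first half of the last dim.
            gate, up = x.chunk(2, dim=-1)
            return F.silu(gate) * up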
sglang/srt/layers/attention/base_attn_backend.py
CHANGED
@@ -47,7 +47,7 @@ class AttentionBackend(ABC):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-        """Init the metadata for a forward pass for
+        """Init the metadata for a forward pass for replaying a cuda graph."""
         raise NotImplementedError()

     def get_cuda_graph_seq_len_fill_value(self):
sglang/srt/layers/attention/flashattention_backend.py
ADDED
@@ -0,0 +1,434 @@
+from __future__ import annotations
+
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+
+"""
+Support different attention backends.
+Now there are three backends: FlashInfer, Triton and FlashAttention.
+Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+
+from sglang.srt.configs.model_config import AttentionArch
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+from flash_attn_interface import flash_attn_with_kvcache
+
+
+@dataclass
+class FlashAttentionMetadata:
+    """Metadata for decode operations to avoid redundant computations."""
+
+    cu_seqlens_q: torch.Tensor = None
+    cu_seqlens_k: torch.Tensor = None
+    max_seq_len_q: int = 0
+    max_seq_len_k: int = 0
+    window_size: tuple = (-1, -1)
+    page_table: torch.Tensor = None
+    cache_seqlens_int32: torch.Tensor = None
+
+
+class FlashAttentionBackend(AttentionBackend):
+    """FlashAttention backend implementation."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+    ):
+        super().__init__()
+
+        assert not (
+            model_runner.sliding_window_size is not None
+            and model_runner.model_config.is_encoder_decoder
+        ), "Sliding window and cross attention are not supported together"
+
+        # Initialize metadata
+        self.forward_metadata: FlashAttentionMetadata = None
+        self.max_context_len = model_runner.model_config.context_len
+        self.device = model_runner.device
+        self.decode_cuda_graph_metadata = {}
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.page_size = model_runner.page_size
+        self.use_mla = (
+            model_runner.model_config.attention_arch == AttentionArch.MLA
+        ) and (not global_server_args_dict["disable_mla"])
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Initialize forward metadata to cache repetitive calculations."""
+        # Create metadata based on forward mode
+        metadata = FlashAttentionMetadata()
+
+        # Get sequence information
+        seqlens_in_batch = forward_batch.seq_lens
+        # Precompute int32 version of sequence lengths
+        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+        batch_size = len(seqlens_in_batch)
+        device = seqlens_in_batch.device
+        metadata.cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+        # Precompute maximum sequence length
+        metadata.max_seq_len_k = seqlens_in_batch.max().item()
+        # Precompute page table
+        metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+            forward_batch.req_pool_indices, : metadata.max_seq_len_k
+        ]
+
+        # Precompute strided indices
+        # [0, page_size, 2 * page_size, ...]
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
+
+        if forward_batch.forward_mode == ForwardMode.DECODE:
+            # Precompute cumulative sequence lengths
+            metadata.cu_seqlens_q = torch.arange(
+                0, batch_size + 1, dtype=torch.int32, device=device
+            )
+        else:
+            # Precompute cumulative sequence lengths
+            if any(forward_batch.extend_prefix_lens_cpu):
+                extend_seq_lens = forward_batch.extend_seq_lens
+                metadata.cu_seqlens_q = torch.nn.functional.pad(
+                    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
+            else:
+                metadata.cu_seqlens_q = metadata.cu_seqlens_k
+                metadata.max_seq_len_q = metadata.max_seq_len_k
+        self.forward_metadata = metadata
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
+                )
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+
+        # Use precomputed metadata
+        metadata = self.forward_metadata
+
+        # Calculate window size (can be moved to metadata if layer properties don't change)
+        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
+        # here is two side inclusive
+        window_size = (
+            (layer.sliding_window_size, 0)
+            if layer.sliding_window_size is not None
+            else (-1, -1)
+        )
+
+        page_table = metadata.page_table
+
+        # # Use Flash Attention for prefill
+        if not self.use_mla:
+            # Do multi-head attention
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
+            o = flash_attn_with_kvcache(
+                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                k_cache=key_cache,
+                v_cache=value_cache,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=metadata.max_seq_len_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                window_size=window_size,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope_cache = k_rope.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                layer.head_dim - layer.v_head_dim,
+            )
+            c_kv_cache = c_kv.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+            )
+
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope_cache,
+                v_cache=c_kv_cache,
+                qv=q_nope,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=metadata.max_seq_len_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention using precomputed metadata."""
+        # Save KV cache if needed
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
+                )
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+
+        # Use precomputed metadata
+        metadata = self.forward_metadata
+
+        # Calculate window size (can be moved to metadata if layer properties don't change)
+        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
+        # here is two side inclusive
+        window_size = (
+            (layer.sliding_window_size, 0)
+            if layer.sliding_window_size is not None
+            else (-1, -1)
+        )
+
+        page_table = metadata.page_table
+
+        if not self.use_mla:
+            # Do multi-head attention
+
+            # Get KV cache
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
+
+            # Pre-reshape query tensor
+            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+
+            # Run attention with precomputed values
+            o = flash_attn_with_kvcache(
+                q=q_reshaped,
+                k_cache=key_cache,
+                v_cache=value_cache,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=1,
+                softmax_scale=layer.scaling,
+                causal=True,
+                window_size=window_size,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope_cache = k_rope.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                layer.head_dim - layer.v_head_dim,
+            )
+            c_kv_cache = c_kv.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+            )
+
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope_cache,
+                v_cache=c_kv_cache,
+                qv=q_nope,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=1,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        """Initialize CUDA graph state for the attention backend.
+
+        Args:
+            max_bs (int): Maximum batch size to support in CUDA graphs
+
+        This creates fixed-size tensors that will be reused during CUDA graph replay
+        to avoid memory allocations.
+        """
+        # Initialize fixed size tensors for decode operations
+        self.decode_cuda_graph_metadata = {
+            # Page table for token mapping (batch_size, max_context_len)
+            "page_table": torch.zeros(
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "strided_indices": torch.arange(
+                0, self.max_context_len, self.page_size, device=self.device
+            ),
+        }
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        """Initialize forward metadata for capturing CUDA graph."""
+        metadata = FlashAttentionMetadata()
+        # Get sequence information
+        metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+        batch_size = len(seq_lens)
+        device = seq_lens.device
+        metadata.cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+        )
+        # Precompute maximum sequence length
+        metadata.max_seq_len_k = seq_lens.max().item()
+        # Precompute page table
+        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+            req_pool_indices, :
+        ]
+        if forward_mode == ForwardMode.DECODE:
+            # Precompute cumulative sequence lengths
+            metadata.cu_seqlens_q = torch.arange(
+                0, batch_size + 1, dtype=torch.int32, device=device
+            )
+        else:
+            raise ValueError("Do not support Prefill Mode cuda graph")
+        self.decode_cuda_graph_metadata[bs] = metadata
+        self.forward_metadata = metadata
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        # """Initialize forward metadata for replaying CUDA graph."""
+        metadata = self.decode_cuda_graph_metadata[bs]
+
+        # For CPU operations
+        max_len = seq_lens_cpu[:bs].max().item()
+        metadata.max_seq_len_k = max_len
+
+        # For GPU operations
+        seq_lens_in_batch = seq_lens[:bs]
+        metadata.cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32)
+        metadata.cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+
+        max_seq_pages = (metadata.max_seq_len_k + self.page_size - 1) // self.page_size
+        page_indices = self.req_to_token[
+            :, self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages]
+        ]
+        page_indices = page_indices[req_pool_indices[:bs]] // self.page_size
+        metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+        metadata.page_table[:, max_seq_pages:].fill_(0)
+        self.forward_metadata = metadata
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        """Get the fill value for sequence length in CUDA graph."""
+        return 0
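A self-contained sketch of the bookkeeping the new backend precomputes in init_forward_metadata(): cumulative key lengths for varlen attention and a per-page (rather than per-token) page table. Toy tensor values, no sglang imports:

    import torch

    seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

    # cu_seqlens_k: offset of each request's entries in the flattened KV cache.
    cu_seqlens_k = torch.nn.functional.pad(
        torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
    )
    print(cu_seqlens_k)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)

    # For page_size > 1 the backend keeps one slot per page and converts token
    # indices into page indices, as in init_forward_metadata() above.
    page_size = 4
    req_to_token = torch.arange(3 * 8, dtype=torch.int32).reshape(3, 8)  # toy table
    strided_indices = torch.arange(0, req_to_token.shape[1], page_size)
    page_table = req_to_token[:, strided_indices] // page_size
    print(page_table.shape)  # torch.Size([3, 2])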
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -1008,7 +1008,7 @@ class FlashInferMultiStepDraftBackend:
         global_override_indptr_cpu = None

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        kv_indices = torch.
+        kv_indices = torch.empty(
             (
                 self.speculative_num_steps,
                 forward_batch.batch_size * self.topk * self.max_context_len,
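The hunk above replaces the kv_indices allocation (the removed line is truncated in the diff view) with torch.empty(...). A toy illustration of the buffer shape being reserved (the dtype and sizes are assumptions, not sglang's actual values):

    import torch

    speculative_num_steps, batch_size, topk, max_context_len = 3, 2, 4, 16
    # torch.empty reserves memory without zero-filling it; the assumption here is
    # that every slot is written before it is read.
    kv_indices = torch.empty(
        (speculative_num_steps, batch_size * topk * max_context_len),
        dtype=torch.int32,
    )
    print(kv_indices.shape)  # torch.Size([3, 128])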