sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -33,16 +33,18 @@ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_N
 from sglang.srt.configs import (
     ChatGLMConfig,
     DbrxConfig,
+    DeepseekVL2Config,
     ExaoneConfig,
     MultiModalityConfig,
-    Qwen2_5_VLConfig,
 )
+from sglang.srt.connector import create_remote_connector
+from sglang.srt.utils import is_remote_url

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
     DbrxConfig.model_type: DbrxConfig,
     ExaoneConfig.model_type: ExaoneConfig,
-
+    DeepseekVL2Config.model_type: DeepseekVL2Config,
     MultiModalityConfig.model_type: MultiModalityConfig,
 }

@@ -155,6 +157,14 @@ def get_tokenizer(
         kwargs["gguf_file"] = tokenizer_name
         tokenizer_name = Path(tokenizer_name).parent

+    if is_remote_url(tokenizer_name):
+        # BaseConnector implements __del__() to clean up the local dir.
+        # Since config files need to exist all the time, so we DO NOT use
+        # with statement to avoid closing the client.
+        client = create_remote_connector(tokenizer_name)
+        client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
+        tokenizer_name = client.get_local_dir()
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
@@ -207,11 +217,26 @@ def get_processor(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ):
+    # pop 'revision' from kwargs if present.
+    revision = kwargs.pop("revision", tokenizer_revision)
+
+    config = AutoConfig.from_pretrained(
+        tokenizer_name,
+        trust_remote_code=trust_remote_code,
+        revision=revision,
+        **kwargs,
+    )
+
+    # fix: for Qwen2-VL model, inject default 'size' if not provided.
+    if config.model_type in {"qwen2_vl"}:
+        if "size" not in kwargs:
+            kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
+
     processor = AutoProcessor.from_pretrained(
         tokenizer_name,
         *args,
         trust_remote_code=trust_remote_code,
-
+        revision=revision,
         **kwargs,
     )

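The get_tokenizer() hunk above adds a remote-model path: when the tokenizer name is a remote URL, the files are first pulled into a local directory through a connector and the local path is used from then on. A minimal usage sketch, assuming a hypothetical s3:// location; the connector calls mirror the ones shown in the hunk and the snippet is illustrative rather than part of the package:

# Hypothetical example: load a tokenizer from a remote object store the way the
# new get_tokenizer() branch does. The bucket/model URL below is made up.
from transformers import AutoTokenizer

from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url

tokenizer_name = "s3://example-bucket/example-model"  # hypothetical remote location

if is_remote_url(tokenizer_name):
    # Keep the client referenced: BaseConnector.__del__() removes the local dir,
    # so the connector must stay alive while the downloaded files are in use.
    client = create_remote_connector(tokenizer_name)
    # Config/tokenizer files only; weight files are skipped.
    client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
    tokenizer_name = client.get_local_dir()

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)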
sglang/srt/layers/activation.py
CHANGED
@@ -23,7 +23,9 @@ import torch.nn.functional as F

 from sglang.srt.utils import is_cuda_available

-
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

 from sglang.srt.custom_op import CustomOp
@@ -165,7 +167,7 @@ def get_act_fn(
     return act_fn


-if not
+if not _is_cuda:
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
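For context, the activation.py hunks replace a per-call platform check with a module-level _is_cuda flag that both gates the optional sgl_kernel import and drives the fallback log message. A minimal standalone sketch of the same gating pattern, using torch.cuda.is_available() as a stand-in for sglang's is_cuda_available() helper:

# Sketch of the import-time gate pattern introduced above (stand-in check only).
import logging

import torch

logger = logging.getLogger(__name__)

# Probe the platform once at import time and reuse the flag everywhere below.
_is_cuda = torch.cuda.is_available()

if _is_cuda:
    # Optional fused kernels; only importable on NVIDIA platforms.
    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul  # noqa: F401

if not _is_cuda:
    logger.info(
        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
    )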
sglang/srt/layers/attention/base_attn_backend.py
CHANGED
@@ -47,7 +47,7 @@ class AttentionBackend(ABC):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-        """Init the metadata for a forward pass for
+        """Init the metadata for a forward pass for replaying a cuda graph."""
         raise NotImplementedError()

     def get_cuda_graph_seq_len_fill_value(self):
sglang/srt/layers/attention/flashattention_backend.py
ADDED
@@ -0,0 +1,295 @@
+from __future__ import annotations
+
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+
+"""
+Support different attention backends.
+Now there are three backends: FlashInfer, Triton and FlashAttention.
+Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+from flash_attn_interface import flash_attn_with_kvcache
+
+
+@dataclass
+class FlashAttentionMetadata:
+    """Metadata for decode operations to avoid redundant computations."""
+
+    cu_seqlens_q: torch.Tensor = None
+    cu_seqlens_k: torch.Tensor = None
+    max_seq_len_k: int = 0
+    window_size: tuple = (-1, -1)
+    page_table: torch.Tensor = None
+    cache_seqlens_int32: torch.Tensor = None
+    max_seq_len_q: int = 0
+
+
+class FlashAttentionBackend(AttentionBackend):
+    """FlashAttention backend implementation."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+    ):
+        super().__init__()
+
+        assert not (
+            model_runner.sliding_window_size is not None
+            and model_runner.model_config.is_encoder_decoder
+        ), "Sliding window and cross attention are not supported together"
+
+        # Initialize metadata
+        self.forward_metadata: FlashAttentionMetadata = None
+        self.max_context_len = model_runner.model_config.context_len
+        self.device = model_runner.device
+        self.decode_cuda_graph_metadata = {}
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Initialize forward metadata to cache repetitive calculations."""
+        # Create metadata based on forward mode
+        metadata = FlashAttentionMetadata()
+
+        extend_seq_lens = forward_batch.extend_seq_lens
+        # Get sequence information
+        seqlens_in_batch = forward_batch.seq_lens
+        # Precompute int32 version of sequence lengths
+        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+        batch_size = len(seqlens_in_batch)
+        device = seqlens_in_batch.device
+        metadata.cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+        # Precompute maximum sequence length
+        metadata.max_seq_len_k = seqlens_in_batch.max().item()
+        # Precompute page table
+        metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+            forward_batch.req_pool_indices, : metadata.max_seq_len_k
+        ]
+        if forward_batch.forward_mode == ForwardMode.DECODE:
+            # Precompute cumulative sequence lengths
+            metadata.cu_seqlens_q = torch.arange(
+                0, batch_size + 1, dtype=torch.int32, device=device
+            )
+        else:
+            extend_no_prefix = not any(forward_batch.extend_prefix_lens)
+            # Precompute cumulative sequence lengths
+            if not extend_no_prefix:
+                metadata.cu_seqlens_q = torch.nn.functional.pad(
+                    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
+            else:
+                metadata.cu_seqlens_q = metadata.cu_seqlens_k
+            metadata.max_seq_len_q = seqlens_in_batch.max().item()
+        self.forward_metadata = metadata
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
+
+        # Use precomputed metadata
+        metadata = self.forward_metadata
+
+        # # Use Flash Attention for prefill
+        # Calculate window size (can be moved to metadata if layer properties don't change)
+        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
+        # here is two side inclusive
+        window_size = (
+            (layer.sliding_window_size, 0)
+            if layer.sliding_window_size is not None
+            else (-1, -1)
+        )
+        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        key_cache, value_cache = kv_cache[0], kv_cache[1]
+        o = flash_attn_with_kvcache(
+            q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+            k_cache=key_cache.unsqueeze(1),
+            v_cache=value_cache.unsqueeze(1),
+            page_table=metadata.page_table,
+            cache_seqlens=metadata.cache_seqlens_int32,
+            cu_seqlens_q=metadata.cu_seqlens_q,
+            cu_seqlens_k_new=metadata.cu_seqlens_k,
+            max_seqlen_q=metadata.max_seq_len_q,
+            softmax_scale=layer.scaling,
+            causal=True,
+            window_size=window_size,
+            softcap=layer.logit_cap,
+            k_descale=layer.k_scale,
+            v_descale=layer.v_scale,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention using precomputed metadata."""
+        # Save KV cache if needed
+        if k is not None and v is not None and save_kv_cache:
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+            )
+
+        # Get KV cache
+        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        key_cache, value_cache = kv_cache[0], kv_cache[1]
+
+        # Use precomputed metadata
+        metadata = self.forward_metadata
+
+        # Pre-reshape query tensor
+        q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        # Calculate window size (can be moved to metadata if layer properties don't change)
+        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
+        # here is two side inclusive
+        window_size = (
+            (layer.sliding_window_size, 0)
+            if layer.sliding_window_size is not None
+            else (-1, -1)
+        )
+        # Run attention with precomputed values
+        o = flash_attn_with_kvcache(
+            q=q_reshaped,
+            k_cache=key_cache.unsqueeze(1),
+            v_cache=value_cache.unsqueeze(1),
+            page_table=metadata.page_table,
+            cache_seqlens=metadata.cache_seqlens_int32,
+            cu_seqlens_q=metadata.cu_seqlens_q,
+            cu_seqlens_k_new=metadata.cu_seqlens_k,
+            max_seqlen_q=1,
+            softmax_scale=layer.scaling,
+            causal=True,
+            window_size=window_size,
+            softcap=layer.logit_cap,
+            k_descale=layer.k_scale,
+            v_descale=layer.v_scale,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        """Initialize CUDA graph state for the attention backend.
+
+        Args:
+            max_bs (int): Maximum batch size to support in CUDA graphs
+
+        This creates fixed-size tensors that will be reused during CUDA graph replay
+        to avoid memory allocations.
+        """
+        # Initialize fixed size tensors for decode operations
+        self.decode_cuda_graph_metadata = {
+            # Page table for token mapping (batch_size, max_context_len)
+            "page_table": torch.zeros(
+                max_bs, self.max_context_len, dtype=torch.int32, device=self.device
+            ),
+        }
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        """Initialize forward metadata for capturing CUDA graph."""
+        metadata = FlashAttentionMetadata()
+        # Get sequence information
+        metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+        batch_size = len(seq_lens)
+        device = seq_lens.device
+        metadata.cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+        )
+        # Precompute maximum sequence length
+        metadata.max_seq_len_k = seq_lens.max().item()
+        # Precompute page table
+        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+            req_pool_indices, :
+        ]
+        if forward_mode == ForwardMode.DECODE:
+            # Precompute cumulative sequence lengths
+            metadata.cu_seqlens_q = torch.arange(
+                0, batch_size + 1, dtype=torch.int32, device=device
+            )
+        else:
+            raise ValueError("Do not support Prefill Mode cuda graph")
+        self.decode_cuda_graph_metadata[bs] = metadata
+        self.forward_metadata = metadata
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        # """Initialize forward metadata for replaying CUDA graph."""
+        seqlens_in_batch = seq_lens[:bs]
+        metadata = self.decode_cuda_graph_metadata[bs]
+        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+        metadata.cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+        # Precompute maximum sequence length
+        metadata.max_seq_len_k = seqlens_in_batch.max().item()
+        # Only zero out the part out of max_len_k
+        metadata.page_table[:, metadata.max_seq_len_k :].fill_(0)
+        # Then do the copy
+        metadata.page_table[:, : metadata.max_seq_len_k].copy_(
+            self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k]
+        )
+        self.forward_decode_metadata = metadata
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        """Get the fill value for sequence length in CUDA graph."""
+        return 0
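The new backend's main idea is to precompute the varlen-attention metadata once per batch in init_forward_metadata() and reuse it for every layer's forward call. A standalone sketch of that metadata math with made-up sequence lengths (no sglang objects needed):

# Reproduce the cu_seqlens / max_seq_len computation from init_forward_metadata()
# above with hypothetical per-request KV lengths.
import torch

seq_lens = torch.tensor([5, 9, 3])  # hypothetical per-request KV lengths
batch_size = len(seq_lens)

cache_seqlens_int32 = seq_lens.to(torch.int32)

# Cumulative KV lengths with a leading zero, the layout varlen kernels expect:
# [5, 9, 3] -> [0, 5, 14, 17]
cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
)

# In decode mode each request contributes exactly one query token,
# so the query offsets are simply 0..batch_size.
cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32)

max_seq_len_k = seq_lens.max().item()  # 9; also bounds the page_table slice

print(cu_seqlens_k.tolist())  # [0, 5, 14, 17]
print(cu_seqlens_q.tolist())  # [0, 1, 2, 3]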
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -1008,7 +1008,7 @@ class FlashInferMultiStepDraftBackend:
         global_override_indptr_cpu = None

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        kv_indices = torch.
+        kv_indices = torch.empty(
             (
                 self.speculative_num_steps,
                 forward_batch.batch_size * self.topk * self.max_context_len,
sglang/srt/layers/attention/flashmla_backend.py
ADDED
@@ -0,0 +1,284 @@
+from __future__ import annotations
+
+"""
+Support attention backend for FlashMLA.
+
+#TODO
+Enable speculative sampling in FlashMLA
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+import triton
+from flash_mla import flash_mla_with_kvcache, get_mla_metadata
+
+from sglang.global_config import global_config
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInfo
+
+
+# FlashMLA only supports pagesize=64
+PAGE_SIZE = 64
+# TODO The current setup is hard-coded and will be changed after integrating with MTP.
+Q_LEN = 1
+
+
+@dataclass
+class FlashMLADecodeMetadata:
+    flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+    num_splits: Optional[torch.Tensor] = None
+    block_kv_indices: Optional[torch.Tensor] = None
+
+    def __init__(
+        self,
+        flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        num_splits: Optional[torch.Tensor] = None,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        self.flashmla_metadata = flashmla_metadata
+        self.num_splits = num_splits
+        self.block_kv_indices = block_kv_indices
+
+
+class FlashMLABackend(FlashInferMLAAttnBackend):
+    """Flashinfer attention kernels."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__(
+            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
+        )
+
+        self.num_q_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.num_local_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.forward_metadata: Union[FlashMLADecodeMetadata] = None
+        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.v_head_dim = model_runner.model_config.v_head_dim
+        self.scaling = model_runner.model_config.scaling
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+
+        bs = forward_batch.batch_size
+        spec_info = forward_batch.spec_info
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(
+                    forward_batch.decode_seq_lens_cpu.max().item(), PAGE_SIZE
+                )
+                block_kv_indices = torch.full(
+                    (bs, max_seqlen_pad),
+                    -1,
+                    dtype=torch.int32,
+                    device=forward_batch.seq_lens.device,
+                )
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )
+                mla_metadata, num_splits = get_mla_metadata(
+                    forward_batch.seq_lens.to(torch.int32),
+                    Q_LEN * self.num_q_heads // self.num_kv_heads,
+                    self.num_kv_heads,
+                )
+                self.forward_metadata = FlashMLADecodeMetadata(
+                    mla_metadata,
+                    num_splits,
+                    block_kv_indices,
+                )
+            else:
+                super().init_forward_metadata(forward_batch)
+        else:
+            super().init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        if block_kv_indices is None:
+            cuda_graph_kv_indices = torch.full(
+                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
+                1,
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = block_kv_indices
+
+        self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
+            torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
+            Q_LEN * self.num_q_heads // self.num_kv_heads,
+            self.num_kv_heads,
+        )
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    None,
+                    self.cuda_graph_kv_indices,
+                    self.req_to_token.stride(0),
+                    self.cuda_graph_kv_indices.stride(0),
+                )
+                mla_metadata, num_splits = get_mla_metadata(
+                    seq_lens.to(torch.int32),
+                    Q_LEN * self.num_q_heads // self.num_kv_heads,
+                    self.num_kv_heads,
+                )
+                self.cuda_graph_mla_metadata.copy_(mla_metadata)
+                self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+                self.forward_metadata = FlashMLADecodeMetadata(
+                    self.cuda_graph_mla_metadata,
+                    self.cuda_graph_num_splits[: bs + 1],
+                    self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
+                )
+
+            else:
+                super().init_forward_metadata_capture_cuda_graph(
+                    bs,
+                    num_tokens,
+                    req_pool_indices,
+                    seq_lens,
+                    encoder_lens,
+                    forward_mode,
+                    spec_info,
+                )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+
+        if forward_mode.is_decode_or_idle():
+            assert seq_lens_cpu is not None
+            seq_lens = seq_lens[:bs]
+            seq_lens_cpu = seq_lens_cpu[:bs]
+            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices[:bs],
+                seq_lens,
+                None,
+                self.cuda_graph_kv_indices,
+                self.req_to_token.stride(0),
+                self.cuda_graph_kv_indices.stride(0),
+            )
+            mla_metadata, num_splits = get_mla_metadata(
+                seq_lens.to(torch.int32),
+                Q_LEN * self.num_q_heads // self.num_kv_heads,
+                self.num_kv_heads,
+            )
+            self.cuda_graph_mla_metadata.copy_(mla_metadata)
+            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+            self.forward_metadata.mla_metadata = self.cuda_graph_mla_metadata
+            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
+            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
+                :bs, :max_seqlen_pad
+            ]
+
+        else:
+            super().init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+    ):
+        cache_loc = forward_batch.out_cache_loc
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
+        bs = forward_batch.batch_size
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
+
+        o, _ = flash_mla_with_kvcache(
+            q=reshape_q,
+            k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
+            block_table=self.forward_metadata.block_kv_indices,
+            cache_seqlens=forward_batch.seq_lens.to(torch.int32),
+            head_dim_v=self.kv_lora_rank, # TODO Retrieve from config.
+            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
+            num_splits=self.forward_metadata.num_splits,
+            softmax_scale=layer.scaling,
+            causal=False,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
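FlashMLA addresses the KV cache through a block table of fixed 64-token pages: the table width is ceil(max_seq_len / PAGE_SIZE), computed with triton.cdiv throughout the backend, and unused slots are padded with -1 before create_flashmla_kv_indices_triton fills in the real page ids. A small sketch of that sizing math with hypothetical lengths (plain ceiling division stands in for triton.cdiv):

# Block-table sizing used by FlashMLABackend above, with made-up sequence lengths.
import torch

PAGE_SIZE = 64  # FlashMLA only supports a 64-token page size
seq_lens = torch.tensor([130, 64, 1])  # hypothetical per-request KV lengths


def cdiv(a: int, b: int) -> int:
    # Equivalent to triton.cdiv for positive integers.
    return (a + b - 1) // b


# Enough pages for the longest request: ceil(130 / 64) = 3.
max_seqlen_pad = cdiv(int(seq_lens.max().item()), PAGE_SIZE)

# One row of page indices per request, padded with -1 for unused slots.
block_kv_indices = torch.full((len(seq_lens), max_seqlen_pad), -1, dtype=torch.int32)
print(block_kv_indices.shape)  # torch.Size([3, 3])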