sglang 0.4.5.post3__py3-none-any.whl → 0.4.6.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -9
- sglang/compile_deep_gemm.py +45 -4
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +9 -3
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +67 -13
- sglang/srt/disaggregation/fake/__init__.py +1 -0
- sglang/srt/disaggregation/fake/conn.py +88 -0
- sglang/srt/disaggregation/mini_lb.py +45 -8
- sglang/srt/disaggregation/mooncake/conn.py +198 -31
- sglang/srt/disaggregation/prefill.py +36 -12
- sglang/srt/disaggregation/utils.py +16 -2
- sglang/srt/entrypoints/engine.py +9 -0
- sglang/srt/entrypoints/http_server.py +35 -4
- sglang/srt/function_call_parser.py +77 -5
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
- sglang/srt/layers/attention/flashattention_backend.py +28 -10
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/utils.py +1 -1
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/layernorm.py +38 -16
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -17
- sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/deep_gemm.py +17 -10
- sglang/srt/layers/quantization/fp8.py +20 -22
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +170 -126
- sglang/srt/managers/data_parallel_controller.py +10 -3
- sglang/srt/managers/io_struct.py +7 -0
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +38 -12
- sglang/srt/managers/scheduler.py +41 -28
- sglang/srt/managers/scheduler_output_processor_mixin.py +25 -9
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/managers/tp_worker.py +3 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -4
- sglang/srt/mem_cache/memory_pool.py +87 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -3
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +19 -25
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +144 -70
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpmo.py +5 -1
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +50 -11
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/reasoning_parser.py +25 -1
- sglang/srt/server_args.py +31 -24
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +5 -1
- sglang/test/runners.py +6 -13
- sglang/test/send_one.py +84 -28
- sglang/test/test_utils.py +74 -18
- sglang/version.py +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/METADATA +5 -6
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/RECORD +97 -80
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/top_level.txt +0 -0

sglang/srt/entrypoints/http_server.py:
@@ -42,6 +42,7 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
 
+from sglang.srt.disaggregation.utils import FakeBootstrapHost
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 from sglang.srt.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
@@ -84,6 +85,7 @@ from sglang.srt.utils import (
     add_api_key_middleware,
     add_prometheus_middleware,
     delete_directory,
+    get_bool_env_var,
     kill_process_tree,
     set_uvicorn_logging_configs,
 )
@@ -126,7 +128,10 @@ async def lifespan(fast_api_app: FastAPI):
 
 
 # Fast API
-app = FastAPI(
+app = FastAPI(
+    lifespan=lifespan,
+    openapi_url=None if get_bool_env_var("DISABLE_OPENAPI_DOC") else "/openapi.json",
+)
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
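Note: the new openapi_url argument lets a deployment hide the /openapi.json schema (and with it FastAPI's interactive docs) via an environment variable. The snippet below is a minimal standalone sketch of the same pattern; it reads the environment with os.environ instead of sglang's get_bool_env_var, and the truthy-value handling here is illustrative only.

import os

from fastapi import FastAPI

# Treat "true"/"1" as enabled; sglang's get_bool_env_var helper does something similar.
disable_openapi = os.environ.get("DISABLE_OPENAPI_DOC", "false").lower() in ("true", "1")

app = FastAPI(
    # Setting openapi_url=None turns off the schema endpoint and the docs pages.
    openapi_url=None if disable_openapi else "/openapi.json",
)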
@@ -277,7 +282,9 @@ async def generate_from_file_request(file: UploadFile, request: Request):
     )
 
     try:
-        ret = await _global_state.generate_request(
+        ret = await _global_state.tokenizer_manager.generate_request(
+            obj, request
+        ).__anext__()
         return ret
     except ValueError as e:
         logger.error(f"Error: {e}")
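For readers unfamiliar with the .__anext__() call used above: generate_request returns an async generator, and awaiting __anext__() pulls just its first yielded result. A self-contained toy example follows; the generator body is made up for illustration.

import asyncio

async def generate(prompt):
    # Stand-in for an async generator such as tokenizer_manager.generate_request(...)
    yield {"text": f"completion for {prompt!r}"}

async def main():
    ret = await generate("hello").__anext__()  # first (and here only) yielded item
    print(ret)

asyncio.run(main())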
@@ -815,8 +822,32 @@ def _wait_and_warmup(
             )
             assert res.status_code == 200, f"{res}"
         else:
-
-
+            logger.info(f"Start of prefill warmup ...")
+            json_data = {
+                "sampling_params": {
+                    "temperature": 0.0,
+                    "max_new_tokens": 8,
+                    "ignore_eos": True,
+                },
+                "bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
+                # This is a hack to ensure fake transfer is enabled during prefill warmup
+                # ensure each dp rank has a unique bootstrap_room during prefill warmup
+                "bootstrap_room": [
+                    i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
+                    for i in range(server_args.dp_size)
+                ],
+                "input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
+            }
+            res = requests.post(
+                url + request_name,
+                json=json_data,
+                headers=headers,
+                timeout=1800,  # because of deep gemm precache is very long if not precache.
+            )
+            logger.info(
+                f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
+            )
+
     except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:

sglang/srt/function_call_parser.py:
@@ -491,6 +491,7 @@ class DeepSeekV3Detector(BaseFormatDetector):
         self.eot_token = "<|tool▁calls▁end|>"
         self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>"
         self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```<|tool▁call▁end|>"
+        self._last_arguments = ""
 
     def has_tool_call(self, text: str) -> bool:
         """Check if the text contains a deepseek format tool call."""
@@ -528,13 +529,84 @@ class DeepSeekV3Detector(BaseFormatDetector):
 
     def structure_info(self) -> _GetInfoFunc:
         return lambda name: StructureInfo(
-            begin="
-
-            + "\n```json\n",
-            end="\n```<|tool▁call▁end|><|tool▁calls▁end|>",
-            trigger="<|tool▁calls▁begin|>",
+            begin=">" + name + "\n```json\n",
+            end="\n```<",
+            trigger=">" + name + "\n```json\n",
         )
 
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing tool calls for DeepSeekV3 format.
+        """
+        self._buffer += new_text
+        current_text = self._buffer
+
+        if self.bot_token not in current_text:
+            self._buffer = ""
+            for e_token in [self.eot_token, "```", "<|tool▁call▁end|>"]:
+                if e_token in new_text:
+                    new_text = new_text.replace(e_token, "")
+            return StreamingParseResult(normal_text=new_text)
+
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = {
+                tool.function.name: i
+                for i, tool in enumerate(tools)
+                if tool.function and tool.function.name
+            }
+
+        calls: list[ToolCallItem] = []
+        try:
+            partial_match = re.search(
+                pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)",
+                string=current_text,
+                flags=re.DOTALL,
+            )
+            if partial_match:
+                func_name = partial_match.group(2).strip()
+                func_args_raw = partial_match.group(3).strip()
+
+                if not self.current_tool_name_sent:
+                    calls.append(
+                        ToolCallItem(
+                            tool_index=self._tool_indices.get(func_name, 0),
+                            name=func_name,
+                            parameters="",
+                        )
+                    )
+                    self.current_tool_name_sent = True
+                else:
+                    argument_diff = (
+                        func_args_raw[len(self._last_arguments) :]
+                        if func_args_raw.startswith(self._last_arguments)
+                        else func_args_raw
+                    )
+
+                    if argument_diff:
+                        calls.append(
+                            ToolCallItem(
+                                tool_index=self._tool_indices.get(func_name, 0),
+                                name=None,
+                                parameters=argument_diff,
+                            )
+                        )
+                        self._last_arguments += argument_diff
+
+                    if _is_complete_json(func_args_raw):
+                        result = StreamingParseResult(normal_text="", calls=calls)
+                        self._buffer = ""
+                        self._last_arguments = ""
+                        self.current_tool_name_sent = False
+                        return result
+
+            return StreamingParseResult(normal_text="", calls=calls)
+
+        except Exception as e:
+            logger.error(f"Error in parse_streaming_increment: {e}")
+            return StreamingParseResult(normal_text=current_text)
+
 
 class MultiFormatParser:
     def __init__(self, detectors: List[BaseFormatDetector]):
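The core of the new streaming parser is the argument_diff computation: on every chunk it re-extracts the (possibly partial) JSON argument string and emits only the suffix that has not been sent yet. Below is a standalone sketch of that idea; the helper name and the sample chunks are illustrative and are not part of sglang's API.

def argument_delta(already_sent: str, args_so_far: str) -> str:
    # Emit only the new tail when the previous emission is still a prefix;
    # otherwise fall back to the full snapshot (mirrors the logic above).
    if args_so_far.startswith(already_sent):
        return args_so_far[len(already_sent):]
    return args_so_far

sent = ""
for snapshot in ['{"loc', '{"location": "Par', '{"location": "Paris"}']:
    delta = argument_delta(sent, snapshot)
    sent += delta
    print(repr(delta))
# prints '{"loc', then 'ation": "Par', then 'is"}'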

sglang/srt/layers/attention/base_attn_backend.py:
@@ -62,6 +62,7 @@ class AttentionBackend(ABC):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
@@ -72,6 +73,7 @@ class AttentionBackend(ABC):
                 layer,
                 forward_batch,
                 save_kv_cache=save_kv_cache,
+                **kwargs,
             )
         else:
             return self.forward_extend(
@@ -81,6 +83,7 @@ class AttentionBackend(ABC):
                 layer,
                 forward_batch,
                 save_kv_cache=save_kv_cache,
+                **kwargs,
             )
 
     def forward_decode(
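The added **kwargs simply flow from the generic forward() dispatcher into forward_decode/forward_extend, so subclasses can grow backend-specific keyword arguments (the FlashAttention backend later in this diff adds q_rope/k_rope) without touching the base signature. A minimal sketch of the pattern, with made-up class names:

class Base:
    def forward(self, q, k, v, decode: bool, **kwargs):
        # Unknown keyword arguments are forwarded untouched to the subclass hooks.
        if decode:
            return self.forward_decode(q, k, v, **kwargs)
        return self.forward_extend(q, k, v, **kwargs)

class MLABackend(Base):
    # Extra keyword arguments only this backend understands.
    def forward_decode(self, q, k, v, q_rope=None, k_rope=None):
        return ("decode", q_rope, k_rope)

    def forward_extend(self, q, k, v, q_rope=None, k_rope=None):
        return ("extend", q_rope, k_rope)

print(MLABackend().forward(1, 2, 3, decode=True, q_rope="qr", k_rope="kr"))
# ('decode', 'qr', 'kr')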

sglang/srt/layers/attention/cutlass_mla_backend.py (new file):
@@ -0,0 +1,278 @@
+from __future__ import annotations
+
+"""
+Support attention backend for Cutlass MLA.
+
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+import triton
+
+from sglang.global_config import global_config
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_cuda
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInfo
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
+
+
+# Cutlass MLA only supports pagesize=128
+PAGE_SIZE = 128
+
+
+@dataclass
+class CutlassMLADecodeMetadata:
+    workspace: Optional[torch.Tensor] = None
+    block_kv_indices: Optional[torch.Tensor] = None
+
+    def __init__(
+        self,
+        workspace: Optional[torch.Tensor] = None,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        self.workspace = workspace
+        self.block_kv_indices = block_kv_indices
+
+
+class CutlassMLABackend(FlashInferMLAAttnBackend):
+    """Cutlass attention kernels."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__(
+            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
+        )
+
+        self.num_q_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.num_local_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.forward_metadata: Union[CutlassMLADecodeMetadata] = None
+        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.v_head_dim = model_runner.model_config.v_head_dim
+        self.scaling = model_runner.model_config.scaling
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+
+        bs = forward_batch.batch_size
+        spec_info = forward_batch.spec_info
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(
+                    forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
+                )
+                block_kv_indices = torch.full(
+                    (bs, max_seqlen_pad),
+                    -1,
+                    dtype=torch.int32,
+                    device=forward_batch.seq_lens.device,
+                )
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                    PAGE_SIZE,
+                )
+                workspace_size = cutlass_mla_get_workspace_size(
+                    max_seqlen_pad * PAGE_SIZE, bs
+                )
+                workspace = torch.empty(
+                    workspace_size, device="cuda", dtype=torch.uint8
+                )
+                self.forward_metadata = CutlassMLADecodeMetadata(
+                    workspace,
+                    block_kv_indices,
+                )
+            else:
+                super().init_forward_metadata(forward_batch)
+        else:
+            super().init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        if block_kv_indices is None:
+            cuda_graph_kv_indices = torch.full(
+                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
+                1,
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = block_kv_indices
+
+        workspace_size = cutlass_mla_get_workspace_size(
+            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs
+        )
+        self.cuda_graph_mla_workspace = torch.empty(
+            workspace_size, device="cuda", dtype=torch.uint8
+        )
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    None,
+                    self.cuda_graph_kv_indices,
+                    self.req_to_token.stride(0),
+                    self.cuda_graph_kv_indices.stride(0),
+                    PAGE_SIZE,
+                )
+                workspace_size = cutlass_mla_get_workspace_size(
+                    max_seqlen_pad * PAGE_SIZE, bs
+                )
+                self.cuda_graph_mla_workspace = torch.empty(
+                    workspace_size, device="cuda", dtype=torch.uint8
+                )
+                self.forward_metadata = CutlassMLADecodeMetadata(
+                    self.cuda_graph_mla_workspace,
+                    self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
+                )
+            else:
+                super().init_forward_metadata_capture_cuda_graph(
+                    bs,
+                    num_tokens,
+                    req_pool_indices,
+                    seq_lens,
+                    encoder_lens,
+                    forward_mode,
+                    spec_info,
+                )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+
+        if forward_mode.is_decode_or_idle():
+            assert seq_lens_cpu is not None
+            seq_lens = seq_lens[:bs]
+            seq_lens_cpu = seq_lens_cpu[:bs]
+            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices[:bs],
+                seq_lens,
+                None,
+                self.cuda_graph_kv_indices,
+                self.req_to_token.stride(0),
+                self.cuda_graph_kv_indices.stride(0),
+                PAGE_SIZE,
+            )
+            workspace_size = cutlass_mla_get_workspace_size(
+                max_seqlen_pad * PAGE_SIZE, bs
+            )
+            self.cuda_graph_mla_workspace = torch.empty(
+                workspace_size, device="cuda", dtype=torch.uint8
+            )
+            self.forward_metadata.workspace = self.cuda_graph_mla_workspace
+            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
+                :bs, :max_seqlen_pad
+            ]
+        else:
+            super().init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+    ):
+        cache_loc = forward_batch.out_cache_loc
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
+        bs = forward_batch.batch_size
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+        reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        o = cutlass_mla_decode(
+            q_nope_and_q_pe=reshape_q,
+            kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
+            seq_lens=forward_batch.seq_lens.to(torch.int32),
+            page_table=self.forward_metadata.block_kv_indices,
+            workspace=self.forward_metadata.workspace,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

sglang/srt/layers/attention/flashattention_backend.py:
@@ -623,6 +623,9 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -637,11 +640,11 @@ class FlashAttentionBackend(AttentionBackend):
                         layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
                 else:
-                    forward_batch.token_to_kv_pool.
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                         layer,
                         cache_loc,
                         k,
-
+                        k_rope,
                     )
 
         # Use precomputed metadata across all layers
@@ -815,9 +818,15 @@ class FlashAttentionBackend(AttentionBackend):
             c_kv_cache = c_kv.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )
-
-
-
+            if q_rope is not None:
+                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+                q_rope = q_rope.view(
+                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+                )
+            else:
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]
 
             result = flash_attn_with_kvcache(
                 q=q_rope,
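When no separate q_rope tensor is passed, the packed query is split by slicing the head dimension: the first v_head_dim channels form the "nope" part and the remainder is the rotary part. A standalone illustration of that slicing with made-up dimensions:

import torch

tp_q_head_num, head_dim, v_head_dim = 16, 576, 512   # example sizes only
q = torch.randn(2, tp_q_head_num * head_dim)

q_all = q.contiguous().view(-1, tp_q_head_num, head_dim)
q_nope = q_all[:, :, :v_head_dim]          # latent ("nope") portion
q_rope = q_all[:, :, v_head_dim:]          # rotary portion
print(q_nope.shape, q_rope.shape)          # torch.Size([2, 16, 512]) torch.Size([2, 16, 64])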
@@ -877,6 +886,9 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -891,11 +903,11 @@ class FlashAttentionBackend(AttentionBackend):
                         layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
                 else:
-                    forward_batch.token_to_kv_pool.
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                         layer,
                         cache_loc,
                         k,
-
+                        k_rope,
                    )
 
         # Use precomputed metadata across all layers
@@ -1047,9 +1059,15 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )
 
-
-
-
+            if q_rope is not None:
+                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+                q_rope = q_rope.view(
+                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+                )
+            else:
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]
             max_seqlen_q = metadata.max_seq_len_q
 
             result = flash_attn_with_kvcache(

sglang/srt/layers/attention/flashmla_backend.py:
@@ -68,9 +68,6 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         self.num_q_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
-        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            get_attention_tp_size()
-        )
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.num_local_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -111,8 +108,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 )
                 mla_metadata, num_splits = get_mla_metadata(
                     forward_batch.seq_lens.to(torch.int32),
-                    Q_LEN * self.num_q_heads
-
+                    Q_LEN * self.num_q_heads,
+                    1,
                 )
                 self.forward_metadata = FlashMLADecodeMetadata(
                     mla_metadata,
@@ -141,8 +138,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
 
         self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
             torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
-            Q_LEN * self.num_q_heads
-
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.cuda_graph_kv_indices = cuda_graph_kv_indices
 
@@ -171,8 +168,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                )
                mla_metadata, num_splits = get_mla_metadata(
                    seq_lens.to(torch.int32),
-                   Q_LEN * self.num_q_heads
-
+                   Q_LEN * self.num_q_heads,
+                   1,
                )
                self.cuda_graph_mla_metadata.copy_(mla_metadata)
                self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -221,8 +218,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
            )
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
-               Q_LEN * self.num_q_heads
-
+               Q_LEN * self.num_q_heads,
+               1,
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)

sglang/srt/layers/attention/utils.py:
@@ -49,8 +49,8 @@ def create_flashmla_kv_indices_triton(
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
     kv_indices_ptr_stride: tl.constexpr,
+    PAGED_SIZE: tl.constexpr = 64,
 ):
-    PAGED_SIZE: tl.constexpr = 64
     BLOCK_SIZE: tl.constexpr = 4096
     NUM_PAGE_PER_BLOCK: tl.constexpr = 64
     pid = tl.program_id(axis=0)