sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -0
- sglang/check_env.py +3 -3
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +15 -0
- sglang/srt/conversation.py +122 -1
- sglang/srt/disaggregation/decode.py +8 -2
- sglang/srt/disaggregation/fake/__init__.py +1 -0
- sglang/srt/disaggregation/fake/conn.py +88 -0
- sglang/srt/disaggregation/prefill.py +12 -3
- sglang/srt/disaggregation/utils.py +16 -2
- sglang/srt/entrypoints/engine.py +52 -21
- sglang/srt/entrypoints/http_server.py +27 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
- sglang/srt/layers/attention/flashinfer_backend.py +107 -82
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +1 -1
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
- sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +1 -1
- sglang/srt/layers/quantization/fp8.py +20 -22
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +84 -35
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +34 -15
- sglang/srt/managers/scheduler.py +273 -67
- sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
- sglang/srt/managers/tp_worker.py +52 -17
- sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +123 -58
- sglang/srt/models/deepseek_nextn.py +1 -257
- sglang/srt/models/deepseek_v2.py +78 -18
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +92 -30
- sglang/srt/models/llama4.py +2 -1
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +0 -12
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/openai_api/adapter.py +49 -8
- sglang/srt/openai_api/protocol.py +13 -1
- sglang/srt/reasoning_parser.py +25 -1
- sglang/srt/server_args.py +83 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +91 -9
- sglang/test/runners.py +4 -0
- sglang/test/send_one.py +84 -28
- sglang/test/test_utils.py +67 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
```diff
--- a/sglang/srt/entrypoints/http_server.py
+++ b/sglang/srt/entrypoints/http_server.py
@@ -42,6 +42,7 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
 
+from sglang.srt.disaggregation.utils import FakeBootstrapHost
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 from sglang.srt.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
@@ -821,8 +822,32 @@ def _wait_and_warmup(
             )
             assert res.status_code == 200, f"{res}"
         else:
-
-
+            logger.info(f"Start of prefill warmup ...")
+            json_data = {
+                "sampling_params": {
+                    "temperature": 0.0,
+                    "max_new_tokens": 8,
+                    "ignore_eos": True,
+                },
+                "bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
+                # This is a hack to ensure fake transfer is enabled during prefill warmup
+                # ensure each dp rank has a unique bootstrap_room during prefill warmup
+                "bootstrap_room": [
+                    i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
+                    for i in range(server_args.dp_size)
+                ],
+                "input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
+            }
+            res = requests.post(
+                url + request_name,
+                json=json_data,
+                headers=headers,
+                timeout=1800,  # because of deep gemm precache is very long if not precache.
+            )
+            logger.info(
+                f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
+            )
+
     except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
```
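The added warmup branch sends a single batched prefill request whose per-field lists cover every data-parallel rank, and the `bootstrap_room` comprehension places each rank in its own slice of the 63-bit id space so the rooms never collide. A minimal standalone sketch of that room assignment (the `dp_size`, `tp_size`, and `"fake-host"` values below are made up for illustration; the real payload uses the `FakeBootstrapHost` constant from `sglang.srt.disaggregation.utils`):

```python
# Illustrative sketch only: how the prefill-warmup payload spaces out bootstrap_room ids.
dp_size, tp_size = 4, 2

bootstrap_rooms = [
    # each dp rank i gets a room inside its own (2**63 // dp_size)-wide slice
    i * (2**63 // dp_size) + (i % tp_size)
    for i in range(dp_size)
]

warmup_payload = {
    "sampling_params": {"temperature": 0.0, "max_new_tokens": 8, "ignore_eos": True},
    "bootstrap_host": ["fake-host"] * dp_size,
    "bootstrap_room": bootstrap_rooms,
    "input_ids": [[0, 1, 2, 3]] * dp_size,
}

print(bootstrap_rooms)
# [0, 2305843009213693953, 4611686018427387904, 6917529027641081857]
```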
```diff
--- a/sglang/srt/function_call_parser.py
+++ b/sglang/srt/function_call_parser.py
@@ -1,3 +1,4 @@
+import ast
 import json
 import logging
 import re
@@ -664,6 +665,101 @@ class MultiFormatParser:
         return final_normal_text, final_calls
 
 
+class PythonicDetector(BaseFormatDetector):
+    """
+    Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
+    Assumes function call format:
+        [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
+    Arguments are Python literals (not JSON).
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.tool_call_regex = re.compile(
+            r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
+            re.DOTALL,
+        )
+
+    def has_tool_call(self, text: str) -> bool:
+        return bool(self.tool_call_regex.match(text.strip()))
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        # Try parsing the text as a Python list of function calls
+        text = text.strip()
+        if not (text.startswith("[") and text.endswith("]")):
+            # Not a pythonic tool call format
+            return StreamingParseResult(normal_text=text, calls=[])
+        try:
+            module = ast.parse(text)
+            parsed = getattr(module.body[0], "value", None)
+            if not (
+                isinstance(parsed, ast.List)
+                and all(isinstance(e, ast.Call) for e in parsed.elts)
+            ):
+                return StreamingParseResult(normal_text=text, calls=[])
+            calls = []
+            tool_indices = {
+                tool.function.name: i
+                for i, tool in enumerate(tools)
+                if tool.function.name
+            }
+            for call in parsed.elts:
+                if not isinstance(call.func, ast.Name):
+                    continue
+                function_name = call.func.id
+                arguments = {}
+                for keyword in call.keywords:
+                    arguments[keyword.arg] = self._get_parameter_value(keyword.value)
+                calls.append(
+                    ToolCallItem(
+                        tool_index=tool_indices.get(function_name, -1),
+                        name=function_name,
+                        parameters=json.dumps(arguments, ensure_ascii=False),
+                    )
+                )
+            return StreamingParseResult(normal_text="", calls=calls)
+        except Exception:
+            logger.exception("Error in pythonic tool call parsing.")
+            return StreamingParseResult(normal_text=text, calls=[])
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing for pythonic tool calls.
+        Buffers input until a complete pythonic tool call (from [ to ]) is found,
+        then parses and emits any detected calls.
+        """
+        self._buffer += new_text
+        start = self._buffer.find("[")
+        end = self._buffer.find("]", start)
+        if start != -1 and end != -1:
+            call_text = self._buffer[start : end + 1]
+            result = self.detect_and_parse(call_text, tools)
+            self._buffer = self._buffer[end + 1 :]
+            return result
+        return StreamingParseResult(normal_text="")
+
+    def _get_parameter_value(self, val):
+        if isinstance(val, ast.Constant):
+            return val.value
+        elif isinstance(val, ast.Dict):
+            return {
+                k.value: self._get_parameter_value(v)
+                for k, v in zip(val.keys, val.values)
+            }
+        elif isinstance(val, ast.List):
+            return [self._get_parameter_value(v) for v in val.elts]
+        else:
+            raise ValueError("Tool call arguments must be literals")
+
+    def structure_info(self) -> _GetInfoFunc:
+        def info(name: str):
+            return StructureInfo(begin="[", end="]", trigger="")
+
+        return info
+
+
 class FunctionCallParser:
     """
     In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment
@@ -675,6 +771,7 @@ class FunctionCallParser:
         "qwen25": Qwen25Detector,
         "mistral": MistralDetector,
         "deepseekv3": DeepSeekV3Detector,
+        "pythonic": PythonicDetector,
     }
 
     def __init__(self, tools: List[Tool], tool_call_parser: str):
```
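The newly registered `pythonic` parser accepts Llama-3.2/Llama-4 style bracketed calls whose keyword arguments are Python literals. A small self-contained sketch of the same `ast`-based idea (this is not sglang's `PythonicDetector`; it uses `ast.literal_eval` where the detector walks nodes with its own `_get_parameter_value` helper):

```python
# Standalone illustration of parsing a pythonic tool-call string with ast.
# Not sglang code: a sketch of the same technique used by PythonicDetector.
import ast
import json


def parse_pythonic_calls(text: str):
    parsed = ast.parse(text.strip()).body[0].value
    # Must be a list literal whose elements are plain calls, e.g. [f(a=1), g(b="x")]
    assert isinstance(parsed, ast.List) and all(
        isinstance(e, ast.Call) for e in parsed.elts
    )
    calls = []
    for call in parsed.elts:
        # Keyword arguments must be literals; literal_eval rejects anything else
        args = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
        calls.append({"name": call.func.id, "parameters": json.dumps(args)})
    return calls


print(parse_pythonic_calls('[get_weather(city="Paris", days=3), get_time(tz="UTC")]'))
# [{'name': 'get_weather', 'parameters': '{"city": "Paris", "days": 3}'},
#  {'name': 'get_time', 'parameters': '{"tz": "UTC"}'}]
```

In streaming mode the detector additionally buffers incoming text and only runs this parse once a complete `[...]` span has arrived.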
```diff
--- a/sglang/srt/hf_transformers_utils.py
+++ b/sglang/srt/hf_transformers_utils.py
@@ -35,6 +35,7 @@ from sglang.srt.configs import (
     DbrxConfig,
     DeepseekVL2Config,
     ExaoneConfig,
+    KimiVLConfig,
     MultiModalityConfig,
 )
 from sglang.srt.connector import create_remote_connector
@@ -46,6 +47,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ExaoneConfig.model_type: ExaoneConfig,
     DeepseekVL2Config.model_type: DeepseekVL2Config,
     MultiModalityConfig.model_type: MultiModalityConfig,
+    KimiVLConfig.model_type: KimiVLConfig,
 }
 
 for name, cls in _CONFIG_REGISTRY.items():
```
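Both hunks follow the existing pattern for custom configs: a `model_type` string is mapped to a `PretrainedConfig` subclass and the registry is then handed to transformers' `AutoConfig`. A hedged sketch of that pattern with a made-up config class (`MyVLConfig` is hypothetical, not the real `KimiVLConfig`, and the registration loop body is an assumption about what follows the context line above):

```python
# Sketch of the model_type -> config registration pattern. MyVLConfig is a
# hypothetical stand-in; the real entries are classes like KimiVLConfig.
from transformers import AutoConfig, PretrainedConfig


class MyVLConfig(PretrainedConfig):
    model_type = "my_vl"  # hypothetical model type

    def __init__(self, hidden_size=1024, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


_CONFIG_REGISTRY = {MyVLConfig.model_type: MyVLConfig}

for name, cls in _CONFIG_REGISTRY.items():
    # Make AutoConfig aware of the custom model type
    AutoConfig.register(name, cls, exist_ok=True)

print(AutoConfig.for_model("my_vl").model_type)  # "my_vl"
```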
```diff
--- /dev/null
+++ b/sglang/srt/layers/attention/cutlass_mla_backend.py
@@ -0,0 +1,278 @@
+from __future__ import annotations
+
+"""
+Support attention backend for Cutlass MLA.
+
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+import triton
+
+from sglang.global_config import global_config
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_cuda
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInfo
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
+
+
+# Cutlass MLA only supports pagesize=128
+PAGE_SIZE = 128
+
+
+@dataclass
+class CutlassMLADecodeMetadata:
+    workspace: Optional[torch.Tensor] = None
+    block_kv_indices: Optional[torch.Tensor] = None
+
+    def __init__(
+        self,
+        workspace: Optional[torch.Tensor] = None,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        self.workspace = workspace
+        self.block_kv_indices = block_kv_indices
+
+
+class CutlassMLABackend(FlashInferMLAAttnBackend):
+    """Cutlass attention kernels."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__(
+            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
+        )
+
+        self.num_q_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.num_local_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.forward_metadata: Union[CutlassMLADecodeMetadata] = None
+        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.v_head_dim = model_runner.model_config.v_head_dim
+        self.scaling = model_runner.model_config.scaling
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+
+        bs = forward_batch.batch_size
+        spec_info = forward_batch.spec_info
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(
+                    forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
+                )
+                block_kv_indices = torch.full(
+                    (bs, max_seqlen_pad),
+                    -1,
+                    dtype=torch.int32,
+                    device=forward_batch.seq_lens.device,
+                )
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                    PAGE_SIZE,
+                )
+                workspace_size = cutlass_mla_get_workspace_size(
+                    max_seqlen_pad * PAGE_SIZE, bs
+                )
+                workspace = torch.empty(
+                    workspace_size, device="cuda", dtype=torch.uint8
+                )
+                self.forward_metadata = CutlassMLADecodeMetadata(
+                    workspace,
+                    block_kv_indices,
+                )
+            else:
+                super().init_forward_metadata(forward_batch)
+        else:
+            super().init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        if block_kv_indices is None:
+            cuda_graph_kv_indices = torch.full(
+                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
+                1,
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = block_kv_indices
+
+        workspace_size = cutlass_mla_get_workspace_size(
+            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs
+        )
+        self.cuda_graph_mla_workspace = torch.empty(
+            workspace_size, device="cuda", dtype=torch.uint8
+        )
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    None,
+                    self.cuda_graph_kv_indices,
+                    self.req_to_token.stride(0),
+                    self.cuda_graph_kv_indices.stride(0),
+                    PAGE_SIZE,
+                )
+                workspace_size = cutlass_mla_get_workspace_size(
+                    max_seqlen_pad * PAGE_SIZE, bs
+                )
+                self.cuda_graph_mla_workspace = torch.empty(
+                    workspace_size, device="cuda", dtype=torch.uint8
+                )
+                self.forward_metadata = CutlassMLADecodeMetadata(
+                    self.cuda_graph_mla_workspace,
+                    self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
+                )
+            else:
+                super().init_forward_metadata_capture_cuda_graph(
+                    bs,
+                    num_tokens,
+                    req_pool_indices,
+                    seq_lens,
+                    encoder_lens,
+                    forward_mode,
+                    spec_info,
+                )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+
+        if forward_mode.is_decode_or_idle():
+            assert seq_lens_cpu is not None
+            seq_lens = seq_lens[:bs]
+            seq_lens_cpu = seq_lens_cpu[:bs]
+            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices[:bs],
+                seq_lens,
+                None,
+                self.cuda_graph_kv_indices,
+                self.req_to_token.stride(0),
+                self.cuda_graph_kv_indices.stride(0),
+                PAGE_SIZE,
+            )
+            workspace_size = cutlass_mla_get_workspace_size(
+                max_seqlen_pad * PAGE_SIZE, bs
+            )
+            self.cuda_graph_mla_workspace = torch.empty(
+                workspace_size, device="cuda", dtype=torch.uint8
+            )
+            self.forward_metadata.workspace = self.cuda_graph_mla_workspace
+            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
+                :bs, :max_seqlen_pad
+            ]
+        else:
+            super().init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+    ):
+        cache_loc = forward_batch.out_cache_loc
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
+        bs = forward_batch.batch_size
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+        reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        o = cutlass_mla_decode(
+            q_nope_and_q_pe=reshape_q.to(self.q_data_type),
+            kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
+            seq_lens=forward_batch.seq_lens.to(torch.int32),
+            page_table=self.forward_metadata.block_kv_indices,
+            workspace=self.forward_metadata.workspace,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
```