sglang 0.4.6__py3-none-any.whl → 0.4.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/disaggregation/decode.py +8 -2
- sglang/srt/disaggregation/fake/__init__.py +1 -0
- sglang/srt/disaggregation/fake/conn.py +88 -0
- sglang/srt/disaggregation/prefill.py +12 -3
- sglang/srt/disaggregation/utils.py +16 -2
- sglang/srt/entrypoints/engine.py +9 -0
- sglang/srt/entrypoints/http_server.py +27 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
- sglang/srt/layers/attention/utils.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
- sglang/srt/layers/quantization/fp8.py +20 -22
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/managers/schedule_batch.py +9 -0
- sglang/srt/managers/scheduler.py +10 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +25 -9
- sglang/srt/managers/tp_worker.py +3 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -4
- sglang/srt/model_executor/model_runner.py +8 -1
- sglang/srt/openai_api/adapter.py +32 -3
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/reasoning_parser.py +25 -1
- sglang/srt/server_args.py +16 -2
- sglang/srt/utils.py +3 -0
- sglang/test/send_one.py +84 -28
- sglang/test/test_utils.py +38 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/METADATA +2 -2
- {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/RECORD +44 -29
- {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/decode.py
CHANGED
@@ -32,6 +32,7 @@ from torch.distributed import ProcessGroup
 from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    FakeBootstrapHost,
     KVClassType,
     ReqToMetadataIdxAllocator,
     TransferBackend,
@@ -133,8 +134,13 @@ class DecodePreallocQueue:
 
     def add(self, req: Req) -> None:
         """Add a request to the pending queue."""
-
-
+        if req.bootstrap_host == FakeBootstrapHost:
+            # Fake transfer for warmup reqs
+            kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
+        else:
+            kv_receiver_class = get_kv_class(
+                self.transfer_backend, KVClassType.RECEIVER
+            )
         kv_receiver = kv_receiver_class(
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
sglang/srt/disaggregation/fake/__init__.py
ADDED
@@ -0,0 +1 @@
+from .conn import FakeKVReceiver, FakeKVSender
sglang/srt/disaggregation/fake/conn.py
ADDED
@@ -0,0 +1,88 @@
+import logging
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import numpy.typing as npt
+
+from sglang.srt.disaggregation.base.conn import (
+    BaseKVManager,
+    BaseKVReceiver,
+    BaseKVSender,
+    KVArgs,
+    KVPoll,
+)
+
+logger = logging.getLogger(__name__)
+
+
+# For warmup reqs, we don't kv transfer, we use the fake sender and receiver
+class FakeKVSender(BaseKVSender):
+    def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
+        self.has_sent = False
+
+    def poll(self) -> KVPoll:
+        if self.has_sent is False:
+            # Assume handshake completed instantly
+            return KVPoll.WaitingForInput
+        else:
+            # Assume transfer completed instantly
+            logger.info("FakeKVSender poll success")
+            return KVPoll.Success
+
+    def init(
+        self,
+        kv_indices: list[int],
+        aux_index: Optional[int] = None,
+        dest_ranks: Optional[list[int]] = None,
+    ):
+        logger.info(
+            f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}, dest_ranks: {dest_ranks}"
+        )
+        pass
+
+    def send(
+        self,
+        kv_indices: npt.NDArray[np.int64],
+        index_slice: slice,
+        is_last: bool,
+    ):
+        logger.info(
+            f"FakeKVSender send with kv_indices: {kv_indices}, index_slice: {index_slice}, is_last: {is_last}"
+        )
+        if is_last:
+            self.has_sent = True
+            logger.info(f"FakeKVSender send success")
+        else:
+            self.has_sent = False
+            logger.info(f"FakeKVSender send fake transfering")
+
+    def failure_exception(self):
+        raise Exception("Fake KVSender Exception")
+
+
+class FakeKVReceiver(BaseKVReceiver):
+    def __init__(
+        self,
+        mgr: BaseKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: Optional[int] = None,
+    ):
+        self.has_init = False
+
+    def poll(self) -> KVPoll:
+        if self.has_init is False:
+            # Assume handshake completed instantly
+            return KVPoll.WaitingForInput
+        else:
+            # Assume transfer completed instantly
+            logger.info("FakeKVReceiver poll success")
+            return KVPoll.Success
+
+    def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
+        self.has_init = True
+        logger.info(
+            f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
+        )
+
+    def failure_exception(self):
+        raise Exception("Fake KVReceiver Exception")
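The fake classes above form a two-state poll machine: poll() reports KVPoll.WaitingForInput until the (pretend) transfer completes, then KVPoll.Success. A minimal sketch of driving FakeKVSender directly, assuming the modules added above are importable; passing mgr=None is only viable here because the fake sender ignores its manager:

import numpy as np

from sglang.srt.disaggregation.base.conn import KVPoll
from sglang.srt.disaggregation.fake import FakeKVSender

# Handshake is assumed to complete instantly, so the first poll waits for input.
sender = FakeKVSender(mgr=None, bootstrap_addr="2.2.2.2:8998", bootstrap_room=0)
assert sender.poll() == KVPoll.WaitingForInput

# Sending the last chunk flips the state; the next poll reports success.
sender.send(np.array([0, 1, 2, 3], dtype=np.int64), index_slice=slice(0, 4), is_last=True)
assert sender.poll() == KVPoll.Success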
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
 from __future__ import annotations
 
 import logging
+import threading
 from collections import deque
 from typing import TYPE_CHECKING, List, Optional
 
@@ -28,6 +29,7 @@ import torch
 from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    FakeBootstrapHost,
     KVClassType,
     ReqToMetadataIdxAllocator,
     TransferBackend,
@@ -115,7 +117,11 @@ class PrefillBootstrapQueue:
         return kv_manager
 
     def add(self, req: Req) -> None:
-
+        if req.bootstrap_host == FakeBootstrapHost:
+            # Fake transfer for warmup reqs
+            kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
+        else:
+            kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
         req.disagg_kv_sender = kv_sender_class(
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
@@ -256,7 +262,10 @@ class SchedulerDisaggregationPrefillMixin:
         self.running_batch.batch_is_full = False
 
     def process_batch_result_disagg_prefill(
-        self: Scheduler,
+        self: Scheduler,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+        launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
         Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -280,7 +289,7 @@ class SchedulerDisaggregationPrefillMixin:
         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids = self.tp_worker.
+            _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
         else:
             next_token_ids = result.next_token_ids.tolist()
 
sglang/srt/disaggregation/utils.py
CHANGED
@@ -15,6 +15,9 @@ class DisaggregationMode(Enum):
     DECODE = "decode"
 
 
+FakeBootstrapHost = "2.2.2.2"
+
+
 def poll_and_all_reduce(pollers, gloo_group):
     polls = [int(poller.poll()) for poller in pollers]
     tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
@@ -59,6 +62,8 @@ class KVClassType(Enum):
 
 
 def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
+    from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
+
     if transfer_backend == TransferBackend.MOONCAKE:
         from sglang.srt.disaggregation.mooncake import (
             MooncakeKVBootstrapServer,
@@ -70,7 +75,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         class_mapping = {
             KVClassType.MANAGER: MooncakeKVManager,
             KVClassType.SENDER: MooncakeKVSender,
-            KVClassType.RECEIVER: MooncakeKVReceiver,
+            KVClassType.RECEIVER: (MooncakeKVReceiver),
             KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
         }
         return class_mapping.get(class_type)
@@ -85,10 +90,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         class_mapping = {
             KVClassType.MANAGER: NixlKVManager,
             KVClassType.SENDER: NixlKVSender,
-            KVClassType.RECEIVER: NixlKVReceiver,
+            KVClassType.RECEIVER: (NixlKVReceiver),
             KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
         }
         return class_mapping.get(class_type)
+    if transfer_backend == TransferBackend.FAKE:
+        from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
+
+        class_mapping = {
+            KVClassType.SENDER: FakeKVSender,
+            KVClassType.RECEIVER: (FakeKVReceiver),
+        }
+        return class_mapping.get(class_type)
+
     raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
 
 
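Together these changes make get_kv_class the single dispatch point for all transfer backends, with FakeBootstrapHost acting as the sentinel that the prefill and decode queues compare against. A minimal sketch of that selection logic (pick_receiver_class is a hypothetical helper name, not part of the package):

from sglang.srt.disaggregation.utils import (
    FakeBootstrapHost,
    KVClassType,
    TransferBackend,
    get_kv_class,
)

def pick_receiver_class(bootstrap_host: str, backend: TransferBackend):
    # Warmup requests are tagged with the sentinel host and get the no-op
    # classes; real requests resolve to the configured backend.
    if bootstrap_host == FakeBootstrapHost:
        return get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
    return get_kv_class(backend, KVClassType.RECEIVER)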
sglang/srt/entrypoints/engine.py
CHANGED
@@ -66,6 +66,7 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     get_zmq_socket,
+    is_cuda,
     kill_process_tree,
     launch_dummy_health_check_server,
     maybe_set_triton_cache_manager,
@@ -78,6 +79,8 @@ from sglang.version import __version__
 logger = logging.getLogger(__name__)
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
+_is_cuda = is_cuda()
+
 
 class Engine(EngineBase):
     """
@@ -452,6 +455,12 @@ def _set_envs_and_config(server_args: ServerArgs):
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
         )
+    if _is_cuda:
+        assert_pkg_version(
+            "sgl-kernel",
+            "0.1.0",
+            "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
+        )
 
     def sigchld_handler(signum, frame):
         pid, exitcode = os.waitpid(0, os.WNOHANG)
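The new guard reuses sglang's assert_pkg_version helper to require sgl-kernel >= 0.1.0 on CUDA builds. For context, a hypothetical standalone equivalent of that check (not sglang's implementation):

from importlib.metadata import version

from packaging.version import parse

def check_sgl_kernel(min_version: str = "0.1.0") -> None:
    # Raise if the installed sgl-kernel wheel is older than the floor version.
    installed = version("sgl-kernel")
    if parse(installed) < parse(min_version):
        raise RuntimeError(
            f"sgl-kernel {installed} < {min_version}; reinstall with "
            "`pip install sgl-kernel --force-reinstall`"
        )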
sglang/srt/entrypoints/http_server.py
CHANGED
@@ -42,6 +42,7 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
 
+from sglang.srt.disaggregation.utils import FakeBootstrapHost
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 from sglang.srt.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
@@ -821,8 +822,32 @@ def _wait_and_warmup(
             )
             assert res.status_code == 200, f"{res}"
         else:
-
-
+            logger.info(f"Start of prefill warmup ...")
+            json_data = {
+                "sampling_params": {
+                    "temperature": 0.0,
+                    "max_new_tokens": 8,
+                    "ignore_eos": True,
+                },
+                "bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
+                # This is a hack to ensure fake transfer is enabled during prefill warmup
+                # ensure each dp rank has a unique bootstrap_room during prefill warmup
+                "bootstrap_room": [
+                    i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
+                    for i in range(server_args.dp_size)
+                ],
+                "input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
+            }
+            res = requests.post(
+                url + request_name,
+                json=json_data,
+                headers=headers,
+                timeout=1800,  # because of deep gemm precache is very long if not precache.
+            )
+            logger.info(
+                f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
+            )
+
     except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
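The bootstrap_room list above offsets each dp rank into a disjoint slice of the 2**63 room space, so concurrent warmup requests cannot collide. A worked example of the formula, assuming dp_size=2 and tp_size=2:

dp_size, tp_size = 2, 2
rooms = [i * (2**63 // dp_size) + (i % tp_size) for i in range(dp_size)]
print(rooms)  # [0, 4611686018427387905] -- ranks land in disjoint halves of the space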
sglang/srt/layers/attention/cutlass_mla_backend.py
ADDED
@@ -0,0 +1,278 @@
+from __future__ import annotations
+
+"""
+Support attention backend for Cutlass MLA.
+
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+import triton
+
+from sglang.global_config import global_config
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_cuda
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInfo
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
+
+
+# Cutlass MLA only supports pagesize=128
+PAGE_SIZE = 128
+
+
+@dataclass
+class CutlassMLADecodeMetadata:
+    workspace: Optional[torch.Tensor] = None
+    block_kv_indices: Optional[torch.Tensor] = None
+
+    def __init__(
+        self,
+        workspace: Optional[torch.Tensor] = None,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        self.workspace = workspace
+        self.block_kv_indices = block_kv_indices
+
+
+class CutlassMLABackend(FlashInferMLAAttnBackend):
+    """Cutlass attention kernels."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__(
+            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
+        )
+
+        self.num_q_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.num_local_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.forward_metadata: Union[CutlassMLADecodeMetadata] = None
+        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.v_head_dim = model_runner.model_config.v_head_dim
+        self.scaling = model_runner.model_config.scaling
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+
+        bs = forward_batch.batch_size
+        spec_info = forward_batch.spec_info
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(
+                    forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
+                )
+                block_kv_indices = torch.full(
+                    (bs, max_seqlen_pad),
+                    -1,
+                    dtype=torch.int32,
+                    device=forward_batch.seq_lens.device,
+                )
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                    PAGE_SIZE,
+                )
+                workspace_size = cutlass_mla_get_workspace_size(
+                    max_seqlen_pad * PAGE_SIZE, bs
+                )
+                workspace = torch.empty(
+                    workspace_size, device="cuda", dtype=torch.uint8
+                )
+                self.forward_metadata = CutlassMLADecodeMetadata(
+                    workspace,
+                    block_kv_indices,
+                )
+            else:
+                super().init_forward_metadata(forward_batch)
+        else:
+            super().init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        if block_kv_indices is None:
+            cuda_graph_kv_indices = torch.full(
+                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
+                1,
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = block_kv_indices
+
+        workspace_size = cutlass_mla_get_workspace_size(
+            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs
+        )
+        self.cuda_graph_mla_workspace = torch.empty(
+            workspace_size, device="cuda", dtype=torch.uint8
+        )
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    None,
+                    self.cuda_graph_kv_indices,
+                    self.req_to_token.stride(0),
+                    self.cuda_graph_kv_indices.stride(0),
+                    PAGE_SIZE,
+                )
+                workspace_size = cutlass_mla_get_workspace_size(
+                    max_seqlen_pad * PAGE_SIZE, bs
+                )
+                self.cuda_graph_mla_workspace = torch.empty(
+                    workspace_size, device="cuda", dtype=torch.uint8
+                )
+                self.forward_metadata = CutlassMLADecodeMetadata(
+                    self.cuda_graph_mla_workspace,
+                    self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
+                )
+            else:
+                super().init_forward_metadata_capture_cuda_graph(
+                    bs,
+                    num_tokens,
+                    req_pool_indices,
+                    seq_lens,
+                    encoder_lens,
+                    forward_mode,
+                    spec_info,
+                )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+
+        if forward_mode.is_decode_or_idle():
+            assert seq_lens_cpu is not None
+            seq_lens = seq_lens[:bs]
+            seq_lens_cpu = seq_lens_cpu[:bs]
+            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices[:bs],
+                seq_lens,
+                None,
+                self.cuda_graph_kv_indices,
+                self.req_to_token.stride(0),
+                self.cuda_graph_kv_indices.stride(0),
+                PAGE_SIZE,
+            )
+            workspace_size = cutlass_mla_get_workspace_size(
+                max_seqlen_pad * PAGE_SIZE, bs
+            )
+            self.cuda_graph_mla_workspace = torch.empty(
+                workspace_size, device="cuda", dtype=torch.uint8
+            )
+            self.forward_metadata.workspace = self.cuda_graph_mla_workspace
+            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
+                :bs, :max_seqlen_pad
+            ]
+        else:
+            super().init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+    ):
+        cache_loc = forward_batch.out_cache_loc
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
+        bs = forward_batch.batch_size
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+        reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        o = cutlass_mla_decode(
+            q_nope_and_q_pe=reshape_q,
+            kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
+            seq_lens=forward_batch.seq_lens.to(torch.int32),
+            page_table=self.forward_metadata.block_kv_indices,
+            workspace=self.forward_metadata.workspace,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
sglang/srt/layers/attention/utils.py
CHANGED
@@ -49,8 +49,8 @@ def create_flashmla_kv_indices_triton(
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
     kv_indices_ptr_stride: tl.constexpr,
+    PAGED_SIZE: tl.constexpr = 64,
 ):
-    PAGED_SIZE: tl.constexpr = 64
     BLOCK_SIZE: tl.constexpr = 4096
     NUM_PAGE_PER_BLOCK: tl.constexpr = 64
     pid = tl.program_id(axis=0)