sglang 0.4.6__py3-none-any.whl → 0.4.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. sglang/srt/disaggregation/decode.py +8 -2
  2. sglang/srt/disaggregation/fake/__init__.py +1 -0
  3. sglang/srt/disaggregation/fake/conn.py +88 -0
  4. sglang/srt/disaggregation/prefill.py +12 -3
  5. sglang/srt/disaggregation/utils.py +16 -2
  6. sglang/srt/entrypoints/engine.py +9 -0
  7. sglang/srt/entrypoints/http_server.py +27 -2
  8. sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
  9. sglang/srt/layers/attention/utils.py +1 -1
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -2
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
  24. sglang/srt/layers/quantization/fp8.py +20 -22
  25. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  26. sglang/srt/managers/schedule_batch.py +9 -0
  27. sglang/srt/managers/scheduler.py +10 -8
  28. sglang/srt/managers/scheduler_output_processor_mixin.py +25 -9
  29. sglang/srt/managers/tp_worker.py +3 -3
  30. sglang/srt/managers/tp_worker_overlap_thread.py +9 -4
  31. sglang/srt/model_executor/model_runner.py +8 -1
  32. sglang/srt/openai_api/adapter.py +32 -3
  33. sglang/srt/openai_api/protocol.py +2 -0
  34. sglang/srt/reasoning_parser.py +25 -1
  35. sglang/srt/server_args.py +16 -2
  36. sglang/srt/utils.py +3 -0
  37. sglang/test/send_one.py +84 -28
  38. sglang/test/test_utils.py +38 -0
  39. sglang/version.py +1 -1
  40. {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/METADATA +2 -2
  41. {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/RECORD +44 -29
  42. {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/WHEEL +0 -0
  43. {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/licenses/LICENSE +0 -0
  44. {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/top_level.txt +0 -0
@@ -32,6 +32,7 @@ from torch.distributed import ProcessGroup
32
32
  from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
33
33
  from sglang.srt.disaggregation.utils import (
34
34
  DisaggregationMode,
35
+ FakeBootstrapHost,
35
36
  KVClassType,
36
37
  ReqToMetadataIdxAllocator,
37
38
  TransferBackend,
@@ -133,8 +134,13 @@ class DecodePreallocQueue:
133
134
 
134
135
  def add(self, req: Req) -> None:
135
136
  """Add a request to the pending queue."""
136
-
137
- kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
137
+ if req.bootstrap_host == FakeBootstrapHost:
138
+ # Fake transfer for warmup reqs
139
+ kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
140
+ else:
141
+ kv_receiver_class = get_kv_class(
142
+ self.transfer_backend, KVClassType.RECEIVER
143
+ )
138
144
  kv_receiver = kv_receiver_class(
139
145
  mgr=self.kv_manager,
140
146
  bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
@@ -0,0 +1 @@
1
+ from .conn import FakeKVReceiver, FakeKVSender
@@ -0,0 +1,88 @@
1
+ import logging
2
+ from typing import Dict, List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+
7
+ from sglang.srt.disaggregation.base.conn import (
8
+ BaseKVManager,
9
+ BaseKVReceiver,
10
+ BaseKVSender,
11
+ KVArgs,
12
+ KVPoll,
13
+ )
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ # Warmup requests skip the real KV transfer and use the fake sender and receiver instead.
19
class FakeKVSender(BaseKVSender):
    """KV sender stand-in that transfers nothing.

    ``poll`` reports ``KVPoll.WaitingForInput`` until ``send`` has been called
    with a truthy ``is_last``; after that it reports ``KVPoll.Success``.
    """

    def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
        # The constructor arguments are accepted but not used by the fake sender.
        self.has_sent = False

    def poll(self) -> KVPoll:
        """Return the simulated transfer state derived from ``has_sent``."""
        if not self.has_sent:
            # Assume handshake completed instantly
            return KVPoll.WaitingForInput

        # Assume transfer completed instantly
        logger.info("FakeKVSender poll success")
        return KVPoll.Success

    def init(
        self,
        kv_indices: List[int],
        aux_index: Optional[int] = None,
        dest_ranks: Optional[List[int]] = None,
    ):
        """Log the arguments; the fake sender keeps no transfer state here."""
        # typing.List is used (it is already imported by this module) so the
        # annotations also evaluate on interpreters without builtin generics.
        logger.info(
            "FakeKVSender init with kv_indices: %s, aux_index: %s, dest_ranks: %s",
            kv_indices,
            aux_index,
            dest_ranks,
        )

    def send(
        self,
        kv_indices: npt.NDArray[np.int64],
        index_slice: slice,
        is_last: bool,
    ):
        """Log the call and record whether the final chunk has been "sent".

        ``has_sent`` becomes True only when ``is_last`` is truthy; any other
        call resets it to False.
        """
        logger.info(
            "FakeKVSender send with kv_indices: %s, index_slice: %s, is_last: %s",
            kv_indices,
            index_slice,
            is_last,
        )
        if is_last:
            self.has_sent = True
            logger.info("FakeKVSender send success")
        else:
            self.has_sent = False
            logger.info("FakeKVSender send fake transfering")

    def failure_exception(self):
        """Always raise a generic ``Exception`` identifying the fake sender."""
        raise Exception("Fake KVSender Exception")
61
+
62
+
63
class FakeKVReceiver(BaseKVReceiver):
    """KV receiver stand-in that receives nothing.

    ``poll`` reports ``KVPoll.WaitingForInput`` until ``init`` has been called;
    after that it reports ``KVPoll.Success``.
    """

    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
    ):
        # The constructor arguments are accepted but not used by the fake receiver.
        self.has_init = False

    def poll(self) -> KVPoll:
        """Return the simulated transfer state derived from ``has_init``."""
        if not self.has_init:
            # Assume handshake completed instantly
            return KVPoll.WaitingForInput

        # Assume transfer completed instantly
        logger.info("FakeKVReceiver poll success")
        return KVPoll.Success

    def init(self, kv_indices: List[int], aux_index: Optional[int] = None):
        """Mark the receiver as initialised and log the arguments."""
        # typing.List is used (it is already imported by this module) so the
        # annotation also evaluates on interpreters without builtin generics.
        self.has_init = True
        logger.info(
            "FakeKVReceiver init with kv_indices: %s, aux_index: %s",
            kv_indices,
            aux_index,
        )

    def failure_exception(self):
        """Always raise a generic ``Exception`` identifying the fake receiver."""
        raise Exception("Fake KVReceiver Exception")
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
20
20
  from __future__ import annotations
21
21
 
22
22
  import logging
23
+ import threading
23
24
  from collections import deque
24
25
  from typing import TYPE_CHECKING, List, Optional
25
26
 
@@ -28,6 +29,7 @@ import torch
28
29
  from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
29
30
  from sglang.srt.disaggregation.utils import (
30
31
  DisaggregationMode,
32
+ FakeBootstrapHost,
31
33
  KVClassType,
32
34
  ReqToMetadataIdxAllocator,
33
35
  TransferBackend,
@@ -115,7 +117,11 @@ class PrefillBootstrapQueue:
115
117
  return kv_manager
116
118
 
117
119
  def add(self, req: Req) -> None:
118
- kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
120
+ if req.bootstrap_host == FakeBootstrapHost:
121
+ # Fake transfer for warmup reqs
122
+ kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
123
+ else:
124
+ kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
119
125
  req.disagg_kv_sender = kv_sender_class(
120
126
  mgr=self.kv_manager,
121
127
  bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
@@ -256,7 +262,10 @@ class SchedulerDisaggregationPrefillMixin:
256
262
  self.running_batch.batch_is_full = False
257
263
 
258
264
  def process_batch_result_disagg_prefill(
259
- self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
265
+ self: Scheduler,
266
+ batch: ScheduleBatch,
267
+ result: GenerationBatchResult,
268
+ launch_done: Optional[threading.Event] = None,
260
269
  ) -> None:
261
270
  """
262
271
  Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -280,7 +289,7 @@ class SchedulerDisaggregationPrefillMixin:
280
289
  # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
281
290
  if self.enable_overlap:
282
291
  # wait
283
- _, next_token_ids = self.tp_worker.resolve_batch_result(bid)
292
+ _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
284
293
  else:
285
294
  next_token_ids = result.next_token_ids.tolist()
286
295
 
@@ -15,6 +15,9 @@ class DisaggregationMode(Enum):
15
15
  DECODE = "decode"
16
16
 
17
17
 
18
+ FakeBootstrapHost = "2.2.2.2"
19
+
20
+
18
21
  def poll_and_all_reduce(pollers, gloo_group):
19
22
  polls = [int(poller.poll()) for poller in pollers]
20
23
  tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
@@ -59,6 +62,8 @@ class KVClassType(Enum):
59
62
 
60
63
 
61
64
  def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
65
+ from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
66
+
62
67
  if transfer_backend == TransferBackend.MOONCAKE:
63
68
  from sglang.srt.disaggregation.mooncake import (
64
69
  MooncakeKVBootstrapServer,
@@ -70,7 +75,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
70
75
  class_mapping = {
71
76
  KVClassType.MANAGER: MooncakeKVManager,
72
77
  KVClassType.SENDER: MooncakeKVSender,
73
- KVClassType.RECEIVER: MooncakeKVReceiver,
78
+ KVClassType.RECEIVER: (MooncakeKVReceiver),
74
79
  KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
75
80
  }
76
81
  return class_mapping.get(class_type)
@@ -85,10 +90,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
85
90
  class_mapping = {
86
91
  KVClassType.MANAGER: NixlKVManager,
87
92
  KVClassType.SENDER: NixlKVSender,
88
- KVClassType.RECEIVER: NixlKVReceiver,
93
+ KVClassType.RECEIVER: (NixlKVReceiver),
89
94
  KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
90
95
  }
91
96
  return class_mapping.get(class_type)
97
+ if transfer_backend == TransferBackend.FAKE:
98
+ from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
99
+
100
+ class_mapping = {
101
+ KVClassType.SENDER: FakeKVSender,
102
+ KVClassType.RECEIVER: (FakeKVReceiver),
103
+ }
104
+ return class_mapping.get(class_type)
105
+
92
106
  raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
93
107
 
94
108
 
@@ -66,6 +66,7 @@ from sglang.srt.utils import (
66
66
  assert_pkg_version,
67
67
  configure_logger,
68
68
  get_zmq_socket,
69
+ is_cuda,
69
70
  kill_process_tree,
70
71
  launch_dummy_health_check_server,
71
72
  maybe_set_triton_cache_manager,
@@ -78,6 +79,8 @@ from sglang.version import __version__
78
79
  logger = logging.getLogger(__name__)
79
80
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
80
81
 
82
+ _is_cuda = is_cuda()
83
+
81
84
 
82
85
  class Engine(EngineBase):
83
86
  """
@@ -452,6 +455,12 @@ def _set_envs_and_config(server_args: ServerArgs):
452
455
  "reinstall the latest version by following the instructions "
453
456
  "at https://docs.flashinfer.ai/installation.html.",
454
457
  )
458
+ if _is_cuda:
459
+ assert_pkg_version(
460
+ "sgl-kernel",
461
+ "0.1.0",
462
+ "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
463
+ )
455
464
 
456
465
  def sigchld_handler(signum, frame):
457
466
  pid, exitcode = os.waitpid(0, os.WNOHANG)
@@ -42,6 +42,7 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
42
42
  from fastapi.middleware.cors import CORSMiddleware
43
43
  from fastapi.responses import ORJSONResponse, Response, StreamingResponse
44
44
 
45
+ from sglang.srt.disaggregation.utils import FakeBootstrapHost
45
46
  from sglang.srt.entrypoints.engine import _launch_subprocesses
46
47
  from sglang.srt.function_call_parser import FunctionCallParser
47
48
  from sglang.srt.managers.io_struct import (
@@ -821,8 +822,32 @@ def _wait_and_warmup(
821
822
  )
822
823
  assert res.status_code == 200, f"{res}"
823
824
  else:
824
- # Warmup request currently hangs in disaggregation mode, so we skip it.
825
- logger.info("Skipping warmup request in disaggregation mode")
825
+ logger.info(f"Start of prefill warmup ...")
826
+ json_data = {
827
+ "sampling_params": {
828
+ "temperature": 0.0,
829
+ "max_new_tokens": 8,
830
+ "ignore_eos": True,
831
+ },
832
+ "bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
833
+ # This is a hack to ensure fake transfer is enabled during prefill warmup
834
+ # ensure each dp rank has a unique bootstrap_room during prefill warmup
835
+ "bootstrap_room": [
836
+ i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
837
+ for i in range(server_args.dp_size)
838
+ ],
839
+ "input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
840
+ }
841
+ res = requests.post(
842
+ url + request_name,
843
+ json=json_data,
844
+ headers=headers,
845
+ timeout=1800, # because of deep gemm precache is very long if not precache.
846
+ )
847
+ logger.info(
848
+ f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
849
+ )
850
+
826
851
  except Exception:
827
852
  last_traceback = get_exception_traceback()
828
853
  if pipe_finish_writer is not None:
@@ -0,0 +1,278 @@
1
+ from __future__ import annotations
2
+
3
+ """
4
+ Support attention backend for Cutlass MLA.
5
+
6
+ """
7
+
8
+ from dataclasses import dataclass
9
+ from typing import TYPE_CHECKING, Optional, Union
10
+
11
+ import torch
12
+ import triton
13
+
14
+ from sglang.global_config import global_config
15
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
16
+ from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
17
+ from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
18
+ from sglang.srt.layers.dp_attention import get_attention_tp_size
19
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
20
+ from sglang.srt.utils import is_cuda
21
+
22
+ if TYPE_CHECKING:
23
+ from sglang.srt.layers.radix_attention import RadixAttention
24
+ from sglang.srt.model_executor.model_runner import ModelRunner
25
+ from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
26
+ from sglang.srt.speculative.spec_info import SpecInfo
27
+
28
+ _is_cuda = is_cuda()
29
+ if _is_cuda:
30
+ from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
31
+
32
+
33
+ # Cutlass MLA only supports pagesize=128
34
+ PAGE_SIZE = 128
35
+
36
+
37
@dataclass
class CutlassMLADecodeMetadata:
    """Per-batch inputs handed to the Cutlass MLA decode kernel.

    Both fields default to ``None`` and may be reassigned after construction.
    The constructor is the one ``@dataclass`` generates; it accepts
    ``workspace`` and ``block_kv_indices`` positionally or by keyword, exactly
    as the previous hand-written ``__init__`` did.
    """

    # Scratch buffer passed to the kernel as ``workspace``.
    workspace: Optional[torch.Tensor] = None
    # Index table passed to the kernel as ``page_table``.
    block_kv_indices: Optional[torch.Tensor] = None
49
+
50
+
51
class CutlassMLABackend(FlashInferMLAAttnBackend):
    """Attention backend that runs MLA decode through the Cutlass kernels.

    Decode/idle batches without speculative info build a paged KV index table
    plus a scratch workspace and are executed by ``cutlass_mla_decode``.
    Every other case is delegated to ``FlashInferMLAAttnBackend``.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        kv_last_page_len_buf: Optional[torch.Tensor] = None,
    ):
        """Initialise the parent backend and cache model dimensions."""
        super().__init__(
            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
        )

        model_config = model_runner.model_config
        # Queried once instead of three times; the value is reused below.
        attn_tp_size = get_attention_tp_size()

        self.num_q_heads = model_config.num_attention_heads // attn_tp_size
        self.num_kv_heads = model_config.get_num_kv_heads(attn_tp_size)
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.num_local_heads = model_config.num_attention_heads // attn_tp_size
        # Unset until one of the init_forward_metadata* methods has run.
        self.forward_metadata: Optional[CutlassMLADecodeMetadata] = None
        self.kv_lora_rank = model_config.kv_lora_rank
        self.qk_nope_head_dim = model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_config.qk_rope_head_dim
        self.v_head_dim = model_config.v_head_dim
        self.scaling = model_config.scaling
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim

    @staticmethod
    def _alloc_decode_workspace(num_pages: int, bs: int) -> torch.Tensor:
        """Allocate an uninitialised uint8 CUDA workspace for ``num_pages`` pages and ``bs`` requests."""
        workspace_size = cutlass_mla_get_workspace_size(num_pages * PAGE_SIZE, bs)
        return torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Build decode metadata for ``forward_batch``.

        Only decode/idle batches without ``spec_info`` are handled here; all
        other batches are forwarded to the parent implementation.
        """
        handled_here = (
            forward_batch.forward_mode.is_decode_or_idle()
            and forward_batch.spec_info is None
        )
        if not handled_here:
            super().init_forward_metadata(forward_batch)
            return

        bs = forward_batch.batch_size
        # Number of PAGE_SIZE pages needed by the longest sequence in the batch.
        max_seqlen_pad = triton.cdiv(
            forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
        )
        block_kv_indices = torch.full(
            (bs, max_seqlen_pad),
            -1,
            dtype=torch.int32,
            device=forward_batch.seq_lens.device,
        )
        create_flashmla_kv_indices_triton[(bs,)](
            self.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            None,
            block_kv_indices,
            self.req_to_token.stride(0),
            max_seqlen_pad,
            PAGE_SIZE,
        )
        self.forward_metadata = CutlassMLADecodeMetadata(
            self._alloc_decode_workspace(max_seqlen_pad, bs),
            block_kv_indices,
        )

    def init_cuda_graph_state(
        self,
        max_bs: int,
        block_kv_indices: Optional[torch.Tensor] = None,
    ):
        """Create the KV index table and workspace used by CUDA graph capture/replay.

        When ``block_kv_indices`` is not supplied, an int32 table filled with 1
        is allocated with ``(max_context_len + PAGE_SIZE) // PAGE_SIZE`` columns.
        """
        if block_kv_indices is None:
            cuda_graph_kv_indices = torch.full(
                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
                1,
                dtype=torch.int32,
                device="cuda",
            )
        else:
            cuda_graph_kv_indices = block_kv_indices

        self.cuda_graph_mla_workspace = self._alloc_decode_workspace(
            cuda_graph_kv_indices.shape[1], max_bs
        )
        self.cuda_graph_kv_indices = cuda_graph_kv_indices

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        """Build decode metadata while a CUDA graph is being captured.

        Mirrors ``init_forward_metadata``: decode/idle without ``spec_info`` is
        handled here and every other combination is delegated to the parent.
        Previously one of those combinations fell through without doing
        anything.
        """
        if not (forward_mode.is_decode_or_idle() and spec_info is None):
            super().init_forward_metadata_capture_cuda_graph(
                bs,
                num_tokens,
                req_pool_indices,
                seq_lens,
                encoder_lens,
                forward_mode,
                spec_info,
            )
            return

        max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)

        create_flashmla_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices,
            seq_lens,
            None,
            self.cuda_graph_kv_indices,
            self.req_to_token.stride(0),
            self.cuda_graph_kv_indices.stride(0),
            PAGE_SIZE,
        )
        self.cuda_graph_mla_workspace = self._alloc_decode_workspace(
            max_seqlen_pad, bs
        )
        self.forward_metadata = CutlassMLADecodeMetadata(
            self.cuda_graph_mla_workspace,
            self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Refresh decode metadata before a captured CUDA graph is replayed.

        Decode/idle batches update the shared KV index table in place and point
        ``forward_metadata`` at the first ``bs`` rows; other modes are
        delegated to the parent implementation.
        """
        if not forward_mode.is_decode_or_idle():
            super().init_forward_metadata_replay_cuda_graph(
                bs,
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                encoder_lens,
                forward_mode,
                spec_info,
                seq_lens_cpu,
            )
            return

        assert seq_lens_cpu is not None
        seq_lens = seq_lens[:bs]
        seq_lens_cpu = seq_lens_cpu[:bs]
        max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
        create_flashmla_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices[:bs],
            seq_lens,
            None,
            self.cuda_graph_kv_indices,
            self.req_to_token.stride(0),
            self.cuda_graph_kv_indices.stride(0),
            PAGE_SIZE,
        )
        # NOTE(review): a fresh workspace tensor is allocated on every replay,
        # as in the original code. Confirm this is compatible with the buffer
        # captured in the graph, or reuse the one from init_cuda_graph_state.
        self.cuda_graph_mla_workspace = self._alloc_decode_workspace(
            max_seqlen_pad, bs
        )
        self.forward_metadata.workspace = self.cuda_graph_mla_workspace
        self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
            :bs, :max_seqlen_pad
        ]

    def get_cuda_graph_seq_len_fill_value(self):
        """Return the sequence length used to fill padded CUDA graph slots (1)."""
        return 1

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        """Run one decode step of attention for ``layer``.

        When ``k`` is given and ``save_kv_cache`` is true, the new K/V entries
        are first written to the KV pool at ``forward_batch.out_cache_loc``.
        The result is returned flattened to
        ``(-1, layer.tp_q_head_num * layer.v_head_dim)``.
        """
        cache_loc = forward_batch.out_cache_loc

        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )
        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

        reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)

        o = cutlass_mla_decode(
            q_nope_and_q_pe=reshape_q,
            # The cache is viewed as pages of PAGE_SIZE tokens for the kernel.
            kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
            seq_lens=forward_batch.seq_lens.to(torch.int32),
            page_table=self.forward_metadata.block_kv_indices,
            workspace=self.forward_metadata.workspace,
        )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -49,8 +49,8 @@ def create_flashmla_kv_indices_triton(
49
49
  kv_indices_ptr,
50
50
  req_to_token_ptr_stride: tl.constexpr,
51
51
  kv_indices_ptr_stride: tl.constexpr,
52
+ PAGED_SIZE: tl.constexpr = 64,
52
53
  ):
53
- PAGED_SIZE: tl.constexpr = 64
54
54
  BLOCK_SIZE: tl.constexpr = 4096
55
55
  NUM_PAGE_PER_BLOCK: tl.constexpr = 64
56
56
  pid = tl.program_id(axis=0)