sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/check_env.py +3 -3
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/kimi_vl.py +38 -0
  5. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  6. sglang/srt/configs/model_config.py +15 -0
  7. sglang/srt/conversation.py +122 -1
  8. sglang/srt/disaggregation/decode.py +8 -2
  9. sglang/srt/disaggregation/fake/__init__.py +1 -0
  10. sglang/srt/disaggregation/fake/conn.py +88 -0
  11. sglang/srt/disaggregation/prefill.py +12 -3
  12. sglang/srt/disaggregation/utils.py +16 -2
  13. sglang/srt/entrypoints/engine.py +52 -21
  14. sglang/srt/entrypoints/http_server.py +27 -2
  15. sglang/srt/function_call_parser.py +97 -0
  16. sglang/srt/hf_transformers_utils.py +2 -0
  17. sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
  18. sglang/srt/layers/attention/flashinfer_backend.py +107 -82
  19. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
  20. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  21. sglang/srt/layers/attention/utils.py +1 -1
  22. sglang/srt/layers/dp_attention.py +5 -2
  23. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
  41. sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
  42. sglang/srt/layers/quantization/__init__.py +2 -2
  43. sglang/srt/layers/quantization/deep_gemm.py +1 -1
  44. sglang/srt/layers/quantization/fp8.py +20 -22
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/utils.py +35 -0
  47. sglang/srt/lora/layers.py +35 -9
  48. sglang/srt/lora/lora_manager.py +84 -35
  49. sglang/srt/managers/data_parallel_controller.py +52 -34
  50. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  51. sglang/srt/managers/schedule_batch.py +34 -15
  52. sglang/srt/managers/scheduler.py +273 -67
  53. sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
  54. sglang/srt/managers/tp_worker.py +52 -17
  55. sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
  56. sglang/srt/mem_cache/memory_pool.py +70 -36
  57. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  58. sglang/srt/model_executor/forward_batch_info.py +31 -1
  59. sglang/srt/model_executor/model_runner.py +123 -58
  60. sglang/srt/models/deepseek_nextn.py +1 -257
  61. sglang/srt/models/deepseek_v2.py +78 -18
  62. sglang/srt/models/kimi_vl.py +308 -0
  63. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  64. sglang/srt/models/llama.py +92 -30
  65. sglang/srt/models/llama4.py +2 -1
  66. sglang/srt/models/llama_eagle.py +4 -1
  67. sglang/srt/models/llama_eagle3.py +4 -1
  68. sglang/srt/models/qwen2_moe.py +8 -3
  69. sglang/srt/models/qwen2_vl.py +0 -12
  70. sglang/srt/models/qwen3_moe.py +8 -3
  71. sglang/srt/openai_api/adapter.py +49 -8
  72. sglang/srt/openai_api/protocol.py +13 -1
  73. sglang/srt/reasoning_parser.py +25 -1
  74. sglang/srt/server_args.py +83 -24
  75. sglang/srt/speculative/eagle_worker.py +3 -2
  76. sglang/srt/utils.py +91 -9
  77. sglang/test/runners.py +4 -0
  78. sglang/test/send_one.py +84 -28
  79. sglang/test/test_utils.py +67 -0
  80. sglang/version.py +1 -1
  81. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
  82. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
  83. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
  84. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
  85. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
@@ -42,6 +42,7 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
 
+from sglang.srt.disaggregation.utils import FakeBootstrapHost
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 from sglang.srt.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
@@ -821,8 +822,32 @@ def _wait_and_warmup(
             )
             assert res.status_code == 200, f"{res}"
         else:
-            # Warmup request currently hangs in disaggregation mode, so we skip it.
-            logger.info("Skipping warmup request in disaggregation mode")
+            logger.info("Start of prefill warmup ...")
+            json_data = {
+                "sampling_params": {
+                    "temperature": 0.0,
+                    "max_new_tokens": 8,
+                    "ignore_eos": True,
+                },
+                "bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
+                # Hack: the fake bootstrap host forces fake KV transfer during
+                # prefill warmup; each dp rank also needs a unique bootstrap_room.
+                "bootstrap_room": [
+                    i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
+                    for i in range(server_args.dp_size)
+                ],
+                "input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
+            }
+            res = requests.post(
+                url + request_name,
+                json=json_data,
+                headers=headers,
+                timeout=1800,  # generous: DeepGEMM precompilation is slow when not cached
+            )
+            logger.info(
+                f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
+            )
+
     except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
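The `bootstrap_room` list above is the interesting part: it places one warmup request per data-parallel rank in a disjoint slice of the 63-bit room space. A minimal standalone sketch of the same formula, with made-up dp_size/tp_size values:

    # Sketch of the bootstrap_room assignment used in the warmup payload.
    # dp_size=4 and tp_size=2 are hypothetical example values.
    dp_size, tp_size = 4, 2

    rooms = [
        i * (2**63 // dp_size) + (i % tp_size)
        for i in range(dp_size)
    ]

    # Each dp rank lands in its own slice of [0, 2**63), so no two warmup
    # requests can collide on the same bootstrap room.
    assert len(set(rooms)) == dp_size
    print(rooms)  # [0, 2305843009213693953, 4611686018427387904, 6917529027641081857]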
@@ -1,3 +1,4 @@
+import ast
 import json
 import logging
 import re
@@ -664,6 +665,101 @@ class MultiFormatParser:
         return final_normal_text, final_calls
 
 
+class PythonicDetector(BaseFormatDetector):
+    """
+    Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
+    Assumes function call format:
+        [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
+    Arguments are Python literals (not JSON).
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.tool_call_regex = re.compile(
+            r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
+            re.DOTALL,
+        )
+
+    def has_tool_call(self, text: str) -> bool:
+        return bool(self.tool_call_regex.match(text.strip()))
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        # Try parsing the text as a Python list of function calls
+        text = text.strip()
+        if not (text.startswith("[") and text.endswith("]")):
+            # Not a pythonic tool call format
+            return StreamingParseResult(normal_text=text, calls=[])
+        try:
+            module = ast.parse(text)
+            parsed = getattr(module.body[0], "value", None)
+            if not (
+                isinstance(parsed, ast.List)
+                and all(isinstance(e, ast.Call) for e in parsed.elts)
+            ):
+                return StreamingParseResult(normal_text=text, calls=[])
+            calls = []
+            tool_indices = {
+                tool.function.name: i
+                for i, tool in enumerate(tools)
+                if tool.function.name
+            }
+            for call in parsed.elts:
+                if not isinstance(call.func, ast.Name):
+                    continue
+                function_name = call.func.id
+                arguments = {}
+                for keyword in call.keywords:
+                    arguments[keyword.arg] = self._get_parameter_value(keyword.value)
+                calls.append(
+                    ToolCallItem(
+                        tool_index=tool_indices.get(function_name, -1),
+                        name=function_name,
+                        parameters=json.dumps(arguments, ensure_ascii=False),
+                    )
+                )
+            return StreamingParseResult(normal_text="", calls=calls)
+        except Exception:
+            logger.exception("Error in pythonic tool call parsing.")
+            return StreamingParseResult(normal_text=text, calls=[])
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing for pythonic tool calls.
+        Buffers input until a complete pythonic tool call (from [ to ]) is found,
+        then parses and emits any detected calls.
+        """
+        self._buffer += new_text
+        start = self._buffer.find("[")
+        end = self._buffer.find("]", start)
+        if start != -1 and end != -1:
+            call_text = self._buffer[start : end + 1]
+            result = self.detect_and_parse(call_text, tools)
+            self._buffer = self._buffer[end + 1 :]
+            return result
+        return StreamingParseResult(normal_text="")
+
+    def _get_parameter_value(self, val):
+        if isinstance(val, ast.Constant):
+            return val.value
+        elif isinstance(val, ast.Dict):
+            return {
+                k.value: self._get_parameter_value(v)
+                for k, v in zip(val.keys, val.values)
+            }
+        elif isinstance(val, ast.List):
+            return [self._get_parameter_value(v) for v in val.elts]
+        else:
+            raise ValueError("Tool call arguments must be literals")
+
+    def structure_info(self) -> _GetInfoFunc:
+        def info(name: str):
+            return StructureInfo(begin="[", end="]", trigger="")
+
+        return info
+
+
 class FunctionCallParser:
     """
     In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment
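The `ast`-based parsing in `detect_and_parse` is easy to exercise in isolation. A minimal sketch with a made-up model output, using `ast.literal_eval` (which accepts AST nodes) in place of the detector's `_get_parameter_value`:

    import ast
    import json

    text = '[get_weather(city="Paris", unit="celsius"), get_time(tz="UTC")]'

    # Same shape checks as detect_and_parse: a list literal of Call nodes.
    parsed = ast.parse(text.strip()).body[0].value
    assert isinstance(parsed, ast.List)
    assert all(isinstance(e, ast.Call) for e in parsed.elts)

    for call in parsed.elts:
        # Keyword values are Python literals, so literal_eval is safe here.
        args = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
        print(call.func.id, json.dumps(args, ensure_ascii=False))
    # get_weather {"city": "Paris", "unit": "celsius"}
    # get_time {"tz": "UTC"}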
@@ -675,6 +771,7 @@ class FunctionCallParser:
         "qwen25": Qwen25Detector,
         "mistral": MistralDetector,
         "deepseekv3": DeepSeekV3Detector,
+        "pythonic": PythonicDetector,
     }
 
     def __init__(self, tools: List[Tool], tool_call_parser: str):
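Presumed wiring once the "pythonic" entry is registered (a sketch: `tools` stands for a hypothetical list of Tool definitions, and `parse_non_stream` is assumed to be the module's existing non-streaming entry point):

    parser = FunctionCallParser(tools=tools, tool_call_parser="pythonic")
    normal_text, calls = parser.parse_non_stream('[get_weather(city="Paris")]')
    # calls[0].name == "get_weather"; calls[0].parameters is a JSON string.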
@@ -35,6 +35,7 @@ from sglang.srt.configs import (
     DbrxConfig,
     DeepseekVL2Config,
     ExaoneConfig,
+    KimiVLConfig,
     MultiModalityConfig,
 )
 from sglang.srt.connector import create_remote_connector
@@ -46,6 +47,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ExaoneConfig.model_type: ExaoneConfig,
     DeepseekVL2Config.model_type: DeepseekVL2Config,
     MultiModalityConfig.model_type: MultiModalityConfig,
+    KimiVLConfig.model_type: KimiVLConfig,
 }
 
 for name, cls in _CONFIG_REGISTRY.items():
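The loop in the trailing context above is what makes this table effective: each entry is registered with transformers so checkpoints whose config.json declares a custom model_type resolve to the right class. A sketch of that pattern, assuming the standard AutoConfig.register API:

    import contextlib

    from transformers import AutoConfig

    # KimiVLConfig.model_type now resolves to KimiVLConfig; ValueError is
    # suppressed in case the model_type is already registered.
    for name, cls in _CONFIG_REGISTRY.items():
        with contextlib.suppress(ValueError):
            AutoConfig.register(name, cls)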
@@ -0,0 +1,278 @@
+from __future__ import annotations
+
+"""
+Support attention backend for Cutlass MLA.
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+import triton
+
+from sglang.global_config import global_config
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_cuda
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInfo
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
+
+
+# Cutlass MLA only supports pagesize=128
+PAGE_SIZE = 128
+
+
+@dataclass
+class CutlassMLADecodeMetadata:
+    workspace: Optional[torch.Tensor] = None
+    block_kv_indices: Optional[torch.Tensor] = None
+
+    def __init__(
+        self,
+        workspace: Optional[torch.Tensor] = None,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        self.workspace = workspace
+        self.block_kv_indices = block_kv_indices
+
+
+class CutlassMLABackend(FlashInferMLAAttnBackend):
+    """Cutlass attention kernels."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__(
+            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
+        )
+
+        self.num_q_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.num_local_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.forward_metadata: Union[CutlassMLADecodeMetadata] = None
+        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.v_head_dim = model_runner.model_config.v_head_dim
+        self.scaling = model_runner.model_config.scaling
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+
+        bs = forward_batch.batch_size
+        spec_info = forward_batch.spec_info
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(
+                    forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
+                )
+                block_kv_indices = torch.full(
+                    (bs, max_seqlen_pad),
+                    -1,
+                    dtype=torch.int32,
+                    device=forward_batch.seq_lens.device,
+                )
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                    PAGE_SIZE,
+                )
+                workspace_size = cutlass_mla_get_workspace_size(
+                    max_seqlen_pad * PAGE_SIZE, bs
+                )
+                workspace = torch.empty(
+                    workspace_size, device="cuda", dtype=torch.uint8
+                )
+                self.forward_metadata = CutlassMLADecodeMetadata(
+                    workspace,
+                    block_kv_indices,
+                )
+            else:
+                super().init_forward_metadata(forward_batch)
+        else:
+            super().init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        if block_kv_indices is None:
+            cuda_graph_kv_indices = torch.full(
+                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
+                1,
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = block_kv_indices
+
+        workspace_size = cutlass_mla_get_workspace_size(
+            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs
+        )
+        self.cuda_graph_mla_workspace = torch.empty(
+            workspace_size, device="cuda", dtype=torch.uint8
+        )
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    None,
+                    self.cuda_graph_kv_indices,
+                    self.req_to_token.stride(0),
+                    self.cuda_graph_kv_indices.stride(0),
+                    PAGE_SIZE,
+                )
+                workspace_size = cutlass_mla_get_workspace_size(
+                    max_seqlen_pad * PAGE_SIZE, bs
+                )
+                self.cuda_graph_mla_workspace = torch.empty(
+                    workspace_size, device="cuda", dtype=torch.uint8
+                )
+                self.forward_metadata = CutlassMLADecodeMetadata(
+                    self.cuda_graph_mla_workspace,
+                    self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
+                )
+        else:
+            super().init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+
+        if forward_mode.is_decode_or_idle():
+            assert seq_lens_cpu is not None
+            seq_lens = seq_lens[:bs]
+            seq_lens_cpu = seq_lens_cpu[:bs]
+            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices[:bs],
+                seq_lens,
+                None,
+                self.cuda_graph_kv_indices,
+                self.req_to_token.stride(0),
+                self.cuda_graph_kv_indices.stride(0),
+                PAGE_SIZE,
+            )
+            workspace_size = cutlass_mla_get_workspace_size(
+                max_seqlen_pad * PAGE_SIZE, bs
+            )
+            self.cuda_graph_mla_workspace = torch.empty(
+                workspace_size, device="cuda", dtype=torch.uint8
+            )
+            self.forward_metadata.workspace = self.cuda_graph_mla_workspace
+            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
+                :bs, :max_seqlen_pad
+            ]
+        else:
+            super().init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+    ):
+        cache_loc = forward_batch.out_cache_loc
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
+        bs = forward_batch.batch_size
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+        reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        o = cutlass_mla_decode(
+            q_nope_and_q_pe=reshape_q.to(self.q_data_type),
+            kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
+            seq_lens=forward_batch.seq_lens.to(torch.int32),
+            page_table=self.forward_metadata.block_kv_indices,
+            workspace=self.forward_metadata.workspace,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
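Two sizing rules drive the metadata above: PAGE_SIZE is pinned to 128 because the Cutlass MLA decode kernel only supports that page size, and every decode batch is padded to the longest request's page count so the block table stays dense. A minimal sketch of that arithmetic with made-up sequence lengths:

    import triton  # triton.cdiv is plain ceil-division on integers

    PAGE_SIZE = 128
    seq_lens_cpu = [300, 1024, 77]  # hypothetical per-request KV lengths
    bs = len(seq_lens_cpu)

    # Pad to the longest sequence, as init_forward_metadata does.
    max_seqlen_pad = triton.cdiv(max(seq_lens_cpu), PAGE_SIZE)  # -> 8 pages

    # block_kv_indices is then a dense (bs, max_seqlen_pad) int32 table with
    # -1 in unused slots, and the workspace is sized for
    # max_seqlen_pad * PAGE_SIZE tokens per request.
    print(bs, max_seqlen_pad)  # 3 8

Runtime selection of this backend presumably goes through the attention-backend server argument (server_args.py also changes in this release), but that wiring is outside this hunk.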