sglang 0.4.5.post3__py3-none-any.whl → 0.4.6.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (97)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -9
  3. sglang/compile_deep_gemm.py +45 -4
  4. sglang/srt/code_completion_parser.py +1 -1
  5. sglang/srt/configs/deepseekvl2.py +1 -1
  6. sglang/srt/configs/model_config.py +9 -3
  7. sglang/srt/constrained/llguidance_backend.py +78 -61
  8. sglang/srt/conversation.py +34 -1
  9. sglang/srt/disaggregation/decode.py +67 -13
  10. sglang/srt/disaggregation/fake/__init__.py +1 -0
  11. sglang/srt/disaggregation/fake/conn.py +88 -0
  12. sglang/srt/disaggregation/mini_lb.py +45 -8
  13. sglang/srt/disaggregation/mooncake/conn.py +198 -31
  14. sglang/srt/disaggregation/prefill.py +36 -12
  15. sglang/srt/disaggregation/utils.py +16 -2
  16. sglang/srt/entrypoints/engine.py +9 -0
  17. sglang/srt/entrypoints/http_server.py +35 -4
  18. sglang/srt/function_call_parser.py +77 -5
  19. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  20. sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
  21. sglang/srt/layers/attention/flashattention_backend.py +28 -10
  22. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  23. sglang/srt/layers/attention/utils.py +1 -1
  24. sglang/srt/layers/attention/vision.py +2 -0
  25. sglang/srt/layers/layernorm.py +38 -16
  26. sglang/srt/layers/logits_processor.py +2 -2
  27. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -17
  43. sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
  44. sglang/srt/layers/pooler.py +6 -0
  45. sglang/srt/layers/quantization/awq.py +5 -1
  46. sglang/srt/layers/quantization/deep_gemm.py +17 -10
  47. sglang/srt/layers/quantization/fp8.py +20 -22
  48. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  49. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  50. sglang/srt/layers/radix_attention.py +13 -3
  51. sglang/srt/layers/rotary_embedding.py +170 -126
  52. sglang/srt/managers/data_parallel_controller.py +10 -3
  53. sglang/srt/managers/io_struct.py +7 -0
  54. sglang/srt/managers/mm_utils.py +85 -28
  55. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  56. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  57. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  58. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  59. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  61. sglang/srt/managers/schedule_batch.py +38 -12
  62. sglang/srt/managers/scheduler.py +41 -28
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +25 -9
  64. sglang/srt/managers/tokenizer_manager.py +5 -1
  65. sglang/srt/managers/tp_worker.py +3 -3
  66. sglang/srt/managers/tp_worker_overlap_thread.py +9 -4
  67. sglang/srt/mem_cache/memory_pool.py +87 -0
  68. sglang/srt/model_executor/cuda_graph_runner.py +4 -3
  69. sglang/srt/model_executor/forward_batch_info.py +51 -95
  70. sglang/srt/model_executor/model_runner.py +19 -25
  71. sglang/srt/models/deepseek.py +12 -2
  72. sglang/srt/models/deepseek_nextn.py +101 -6
  73. sglang/srt/models/deepseek_v2.py +144 -70
  74. sglang/srt/models/deepseek_vl2.py +9 -4
  75. sglang/srt/models/gemma3_causal.py +1 -1
  76. sglang/srt/models/llama4.py +0 -1
  77. sglang/srt/models/minicpmo.py +5 -1
  78. sglang/srt/models/mllama4.py +2 -2
  79. sglang/srt/models/qwen2_5_vl.py +3 -6
  80. sglang/srt/models/qwen2_vl.py +3 -7
  81. sglang/srt/models/roberta.py +178 -0
  82. sglang/srt/openai_api/adapter.py +50 -11
  83. sglang/srt/openai_api/protocol.py +2 -0
  84. sglang/srt/reasoning_parser.py +25 -1
  85. sglang/srt/server_args.py +31 -24
  86. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  87. sglang/srt/torch_memory_saver_adapter.py +10 -1
  88. sglang/srt/utils.py +5 -1
  89. sglang/test/runners.py +6 -13
  90. sglang/test/send_one.py +84 -28
  91. sglang/test/test_utils.py +74 -18
  92. sglang/version.py +1 -1
  93. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/METADATA +5 -6
  94. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/RECORD +97 -80
  95. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/WHEEL +1 -1
  96. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/licenses/LICENSE +0 -0
  97. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/top_level.txt +0 -0

sglang/srt/entrypoints/http_server.py

@@ -42,6 +42,7 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse

+from sglang.srt.disaggregation.utils import FakeBootstrapHost
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 from sglang.srt.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
@@ -84,6 +85,7 @@ from sglang.srt.utils import (
     add_api_key_middleware,
     add_prometheus_middleware,
     delete_directory,
+    get_bool_env_var,
     kill_process_tree,
     set_uvicorn_logging_configs,
 )
@@ -126,7 +128,10 @@ async def lifespan(fast_api_app: FastAPI):


 # Fast API
-app = FastAPI(lifespan=lifespan)
+app = FastAPI(
+    lifespan=lifespan,
+    openapi_url=None if get_bool_env_var("DISABLE_OPENAPI_DOC") else "/openapi.json",
+)
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
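
Note: with this change the OpenAPI schema (and therefore /docs) can be switched off via an environment variable. A minimal standalone sketch of the behavior, where get_bool_env_var is a stand-in with assumed truthy-string semantics rather than sglang's own helper:

# Sketch of the DISABLE_OPENAPI_DOC toggle introduced in the hunk above.
# get_bool_env_var here is a local stand-in with assumed "1"/"true" semantics.
import os

from fastapi import FastAPI


def get_bool_env_var(name: str, default: str = "false") -> bool:
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")


app = FastAPI(
    openapi_url=None if get_bool_env_var("DISABLE_OPENAPI_DOC") else "/openapi.json",
)
# With DISABLE_OPENAPI_DOC=1 in the environment, FastAPI does not mount
# /openapi.json or /docs, so both return 404 while API routes stay intact.
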
@@ -277,7 +282,9 @@ async def generate_from_file_request(file: UploadFile, request: Request):
     )

     try:
-        ret = await _global_state.generate_request(obj, request).__anext__()
+        ret = await _global_state.tokenizer_manager.generate_request(
+            obj, request
+        ).__anext__()
         return ret
     except ValueError as e:
         logger.error(f"Error: {e}")
@@ -815,8 +822,32 @@ def _wait_and_warmup(
             )
             assert res.status_code == 200, f"{res}"
         else:
-            # Warmup request currently hangs in disaggregation mode, so we skip it.
-            logger.info("Skipping warmup request in disaggregation mode")
+            logger.info(f"Start of prefill warmup ...")
+            json_data = {
+                "sampling_params": {
+                    "temperature": 0.0,
+                    "max_new_tokens": 8,
+                    "ignore_eos": True,
+                },
+                "bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
+                # This is a hack to ensure fake transfer is enabled during prefill warmup
+                # ensure each dp rank has a unique bootstrap_room during prefill warmup
+                "bootstrap_room": [
+                    i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
+                    for i in range(server_args.dp_size)
+                ],
+                "input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
+            }
+            res = requests.post(
+                url + request_name,
+                json=json_data,
+                headers=headers,
+                timeout=1800,  # because of deep gemm precache is very long if not precache.
+            )
+            logger.info(
+                f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
+            )
+
     except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
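
Note: the warmup payload above gives each data-parallel rank its own bootstrap_room. A standalone sketch of that arithmetic with illustrative dp_size/tp_size values (endpoint, headers, and bootstrap_host are omitted here):

# Illustration of the prefill-warmup payload construction; the sizes are made up.
dp_size, tp_size = 4, 2

# Each data-parallel rank gets a distinct, widely spaced bootstrap_room.
bootstrap_rooms = [
    i * (2**63 // dp_size) + (i % tp_size) for i in range(dp_size)
]
payload = {
    "sampling_params": {"temperature": 0.0, "max_new_tokens": 8, "ignore_eos": True},
    "bootstrap_room": bootstrap_rooms,
    "input_ids": [[0, 1, 2, 3]] * dp_size,
}
print(bootstrap_rooms)
# [0, 2305843009213693953, 4611686018427387904, 6917529027641081857]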

sglang/srt/function_call_parser.py

@@ -491,6 +491,7 @@ class DeepSeekV3Detector(BaseFormatDetector):
         self.eot_token = "<|tool▁calls▁end|>"
         self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>"
         self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```<|tool▁call▁end|>"
+        self._last_arguments = ""

     def has_tool_call(self, text: str) -> bool:
         """Check if the text contains a deepseek format tool call."""
@@ -528,13 +529,84 @@ class DeepSeekV3Detector(BaseFormatDetector):

     def structure_info(self) -> _GetInfoFunc:
         return lambda name: StructureInfo(
-            begin="<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>"
-            + name
-            + "\n```json\n",
-            end="\n```<|tool▁call▁end|><|tool▁calls▁end|>",
-            trigger="<|tool▁calls▁begin|>",
+            begin=">" + name + "\n```json\n",
+            end="\n```<",
+            trigger=">" + name + "\n```json\n",
         )

+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing tool calls for DeepSeekV3 format.
+        """
+        self._buffer += new_text
+        current_text = self._buffer
+
+        if self.bot_token not in current_text:
+            self._buffer = ""
+            for e_token in [self.eot_token, "```", "<|tool▁call▁end|>"]:
+                if e_token in new_text:
+                    new_text = new_text.replace(e_token, "")
+            return StreamingParseResult(normal_text=new_text)
+
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = {
+                tool.function.name: i
+                for i, tool in enumerate(tools)
+                if tool.function and tool.function.name
+            }
+
+        calls: list[ToolCallItem] = []
+        try:
+            partial_match = re.search(
+                pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)",
+                string=current_text,
+                flags=re.DOTALL,
+            )
+            if partial_match:
+                func_name = partial_match.group(2).strip()
+                func_args_raw = partial_match.group(3).strip()
+
+                if not self.current_tool_name_sent:
+                    calls.append(
+                        ToolCallItem(
+                            tool_index=self._tool_indices.get(func_name, 0),
+                            name=func_name,
+                            parameters="",
+                        )
+                    )
+                    self.current_tool_name_sent = True
+                else:
+                    argument_diff = (
+                        func_args_raw[len(self._last_arguments) :]
+                        if func_args_raw.startswith(self._last_arguments)
+                        else func_args_raw
+                    )
+
+                    if argument_diff:
+                        calls.append(
+                            ToolCallItem(
+                                tool_index=self._tool_indices.get(func_name, 0),
+                                name=None,
+                                parameters=argument_diff,
+                            )
+                        )
+                        self._last_arguments += argument_diff
+
+                    if _is_complete_json(func_args_raw):
+                        result = StreamingParseResult(normal_text="", calls=calls)
+                        self._buffer = ""
+                        self._last_arguments = ""
+                        self.current_tool_name_sent = False
+                        return result
+
+            return StreamingParseResult(normal_text="", calls=calls)
+
+        except Exception as e:
+            logger.error(f"Error in parse_streaming_increment: {e}")
+            return StreamingParseResult(normal_text=current_text)
+

 class MultiFormatParser:
     def __init__(self, detectors: List[BaseFormatDetector]):
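
Note: the heart of the new streaming parser is the argument-diff bookkeeping, i.e. emitting only the suffix of the JSON arguments that has not been streamed yet. A stripped-down sketch of that idea follows; it is not the sglang classes, and the real _is_complete_json helper is assumed to behave like a plain parse check:

import json


def is_complete_json(s: str) -> bool:
    # Assumed behavior of the _is_complete_json helper referenced above.
    try:
        json.loads(s)
        return True
    except json.JSONDecodeError:
        return False


class ArgDiffStreamer:
    """Simplified illustration of the incremental-argument logic."""

    def __init__(self) -> None:
        self._last_arguments = ""

    def feed(self, func_args_raw: str) -> str:
        # Forward only the part of the (possibly partial) arguments
        # string that has not been emitted downstream yet.
        if func_args_raw.startswith(self._last_arguments):
            diff = func_args_raw[len(self._last_arguments):]
        else:
            diff = func_args_raw
        self._last_arguments += diff
        return diff


streamer = ArgDiffStreamer()
for chunk in ['{"city": "Pa', '{"city": "Paris"', '{"city": "Paris"}']:
    print(repr(streamer.feed(chunk)))  # '{"city": "Pa'  'ris"'  '}'
print(is_complete_json(streamer._last_arguments))  # True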

sglang/srt/layers/attention/base_attn_backend.py

@@ -62,6 +62,7 @@ class AttentionBackend(ABC):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
@@ -72,6 +73,7 @@
                 layer,
                 forward_batch,
                 save_kv_cache=save_kv_cache,
+                **kwargs,
             )
         else:
             return self.forward_extend(
@@ -81,6 +83,7 @@
                 layer,
                 forward_batch,
                 save_kv_cache=save_kv_cache,
+                **kwargs,
             )

     def forward_decode(
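
Note: the only change in this file is threading **kwargs through the dispatcher, so subclasses can accept backend-specific keyword arguments (such as the MLA q_rope/k_rope tensors added elsewhere in this diff) without touching the base signature. A toy sketch of the pattern with made-up class names:

from typing import Optional

import torch


class Dispatcher:
    # Mirrors the AttentionBackend.forward change: extra kwargs pass through untouched.
    def forward(self, q: torch.Tensor, is_decode: bool, **kwargs) -> torch.Tensor:
        if is_decode:
            return self.forward_decode(q, **kwargs)
        return self.forward_extend(q, **kwargs)


class ToyMLABackend(Dispatcher):
    def forward_decode(
        self, q: torch.Tensor, q_rope: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        return q if q_rope is None else torch.cat([q, q_rope], dim=-1)

    forward_extend = forward_decode


out = ToyMLABackend().forward(torch.ones(2, 4), is_decode=True, q_rope=torch.zeros(2, 2))
print(out.shape)  # torch.Size([2, 6])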

sglang/srt/layers/attention/cutlass_mla_backend.py (new file)

@@ -0,0 +1,278 @@
+from __future__ import annotations
+
+"""
+Support attention backend for Cutlass MLA.
+
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+import triton
+
+from sglang.global_config import global_config
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_cuda
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInfo
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
+
+
+# Cutlass MLA only supports pagesize=128
+PAGE_SIZE = 128
+
+
+@dataclass
+class CutlassMLADecodeMetadata:
+    workspace: Optional[torch.Tensor] = None
+    block_kv_indices: Optional[torch.Tensor] = None
+
+    def __init__(
+        self,
+        workspace: Optional[torch.Tensor] = None,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        self.workspace = workspace
+        self.block_kv_indices = block_kv_indices
+
+
+class CutlassMLABackend(FlashInferMLAAttnBackend):
+    """Cutlass attention kernels."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__(
+            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
+        )
+
+        self.num_q_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.num_local_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.forward_metadata: Union[CutlassMLADecodeMetadata] = None
+        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.v_head_dim = model_runner.model_config.v_head_dim
+        self.scaling = model_runner.model_config.scaling
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+
+        bs = forward_batch.batch_size
+        spec_info = forward_batch.spec_info
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(
+                    forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
+                )
+                block_kv_indices = torch.full(
+                    (bs, max_seqlen_pad),
+                    -1,
+                    dtype=torch.int32,
+                    device=forward_batch.seq_lens.device,
+                )
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                    PAGE_SIZE,
+                )
+                workspace_size = cutlass_mla_get_workspace_size(
+                    max_seqlen_pad * PAGE_SIZE, bs
+                )
+                workspace = torch.empty(
+                    workspace_size, device="cuda", dtype=torch.uint8
+                )
+                self.forward_metadata = CutlassMLADecodeMetadata(
+                    workspace,
+                    block_kv_indices,
+                )
+            else:
+                super().init_forward_metadata(forward_batch)
+        else:
+            super().init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        if block_kv_indices is None:
+            cuda_graph_kv_indices = torch.full(
+                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
+                1,
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = block_kv_indices
+
+        workspace_size = cutlass_mla_get_workspace_size(
+            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs
+        )
+        self.cuda_graph_mla_workspace = torch.empty(
+            workspace_size, device="cuda", dtype=torch.uint8
+        )
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    None,
+                    self.cuda_graph_kv_indices,
+                    self.req_to_token.stride(0),
+                    self.cuda_graph_kv_indices.stride(0),
+                    PAGE_SIZE,
+                )
+                workspace_size = cutlass_mla_get_workspace_size(
+                    max_seqlen_pad * PAGE_SIZE, bs
+                )
+                self.cuda_graph_mla_workspace = torch.empty(
+                    workspace_size, device="cuda", dtype=torch.uint8
+                )
+                self.forward_metadata = CutlassMLADecodeMetadata(
+                    self.cuda_graph_mla_workspace,
+                    self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
+                )
+        else:
+            super().init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+
+        if forward_mode.is_decode_or_idle():
+            assert seq_lens_cpu is not None
+            seq_lens = seq_lens[:bs]
+            seq_lens_cpu = seq_lens_cpu[:bs]
+            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices[:bs],
+                seq_lens,
+                None,
+                self.cuda_graph_kv_indices,
+                self.req_to_token.stride(0),
+                self.cuda_graph_kv_indices.stride(0),
+                PAGE_SIZE,
+            )
+            workspace_size = cutlass_mla_get_workspace_size(
+                max_seqlen_pad * PAGE_SIZE, bs
+            )
+            self.cuda_graph_mla_workspace = torch.empty(
+                workspace_size, device="cuda", dtype=torch.uint8
+            )
+            self.forward_metadata.workspace = self.cuda_graph_mla_workspace
+            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
+                :bs, :max_seqlen_pad
+            ]
+        else:
+            super().init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+    ):
+        cache_loc = forward_batch.out_cache_loc
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
+        bs = forward_batch.batch_size
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+        reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        o = cutlass_mla_decode(
+            q_nope_and_q_pe=reshape_q,
+            kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
+            seq_lens=forward_batch.seq_lens.to(torch.int32),
+            page_table=self.forward_metadata.block_kv_indices,
+            workspace=self.forward_metadata.workspace,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
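
Note: the block table sizing in the new backend rounds the longest sequence up to whole 128-token pages. A quick standalone illustration with made-up sequence lengths (cdiv below mirrors triton.cdiv for positive integers):

import torch

PAGE_SIZE = 128  # Cutlass MLA only supports pagesize=128


def cdiv(a: int, b: int) -> int:
    # Same result as triton.cdiv for positive ints.
    return (a + b - 1) // b


seq_lens_cpu = torch.tensor([37, 500, 129])
max_seqlen_pad = cdiv(int(seq_lens_cpu.max().item()), PAGE_SIZE)
block_kv_indices = torch.full(
    (seq_lens_cpu.numel(), max_seqlen_pad), -1, dtype=torch.int32
)
print(max_seqlen_pad, tuple(block_kv_indices.shape))  # 4 (3, 4)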

sglang/srt/layers/attention/flashattention_backend.py

@@ -623,6 +623,9 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -637,11 +640,11 @@
                         layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
                 else:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                         layer,
                         cache_loc,
                         k,
-                        v,
+                        k_rope,
                     )

         # Use precomputed metadata across all layers
@@ -815,9 +818,15 @@
             c_kv_cache = c_kv.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )
-            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-            q_nope = q_all[:, :, : layer.v_head_dim]
-            q_rope = q_all[:, :, layer.v_head_dim :]
+            if q_rope is not None:
+                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+                q_rope = q_rope.view(
+                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+                )
+            else:
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]

             result = flash_attn_with_kvcache(
                 q=q_rope,
@@ -877,6 +886,9 @@
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -891,11 +903,11 @@
                         layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
                 else:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                         layer,
                         cache_loc,
                         k,
-                        v,
+                        k_rope,
                     )

         # Use precomputed metadata across all layers
@@ -1047,9 +1059,15 @@
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )

-            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-            q_nope = q_all[:, :, : layer.v_head_dim]
-            q_rope = q_all[:, :, layer.v_head_dim :]
+            if q_rope is not None:
+                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+                q_rope = q_rope.view(
+                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+                )
+            else:
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]
             max_seqlen_q = metadata.max_seq_len_q

             result = flash_attn_with_kvcache(
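
Note: the recurring pattern in the last two hunks is the fallback split of a packed query into its no-rope and rope parts along the head dimension when q_rope is not supplied. A standalone sketch with illustrative dimensions (576 = 512 + 64):

import torch

# Illustrative MLA dimensions: head_dim = v_head_dim + rope_dim.
tp_q_head_num, v_head_dim, head_dim = 16, 512, 576

q = torch.randn(4, tp_q_head_num * head_dim)  # packed query for 4 tokens

# Fallback path from the hunks above: slice the packed query.
q_all = q.contiguous().view(-1, tp_q_head_num, head_dim)
q_nope = q_all[:, :, :v_head_dim]   # (4, 16, 512)
q_rope = q_all[:, :, v_head_dim:]   # (4, 16, 64)
print(q_nope.shape, q_rope.shape)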

sglang/srt/layers/attention/flashmla_backend.py

@@ -68,9 +68,6 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         self.num_q_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
-        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            get_attention_tp_size()
-        )
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.num_local_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -111,8 +108,8 @@
                 )
                 mla_metadata, num_splits = get_mla_metadata(
                     forward_batch.seq_lens.to(torch.int32),
-                    Q_LEN * self.num_q_heads // self.num_kv_heads,
-                    self.num_kv_heads,
+                    Q_LEN * self.num_q_heads,
+                    1,
                 )
                 self.forward_metadata = FlashMLADecodeMetadata(
                     mla_metadata,
@@ -141,8 +138,8 @@

         self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
             torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.cuda_graph_kv_indices = cuda_graph_kv_indices

@@ -171,8 +168,8 @@
             )
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-                Q_LEN * self.num_q_heads // self.num_kv_heads,
-                self.num_kv_heads,
+                Q_LEN * self.num_q_heads,
+                1,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -221,8 +218,8 @@
             )
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-                Q_LEN * self.num_q_heads // self.num_kv_heads,
-                self.num_kv_heads,
+                Q_LEN * self.num_q_heads,
+                1,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)

sglang/srt/layers/attention/utils.py

@@ -49,8 +49,8 @@ def create_flashmla_kv_indices_triton(
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
     kv_indices_ptr_stride: tl.constexpr,
+    PAGED_SIZE: tl.constexpr = 64,
 ):
-    PAGED_SIZE: tl.constexpr = 64
     BLOCK_SIZE: tl.constexpr = 4096
     NUM_PAGE_PER_BLOCK: tl.constexpr = 64
     pid = tl.program_id(axis=0)

sglang/srt/layers/attention/vision.py

@@ -271,6 +271,8 @@ class VisionSdpaAttention(nn.Module):
         Returns:
             [b * s, h, head_size]
         """
+        if self.flatten_batch:
+            assert bsz == 1, "flatten_batch is True, bsz must be 1"

         s = q.shape[0] // bsz