sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sglang/bench_serving.py +56 -12
  2. sglang/launch_server.py +2 -0
  3. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
  4. sglang/srt/compilation/backend.py +1 -1
  5. sglang/srt/configs/model_config.py +5 -5
  6. sglang/srt/distributed/parallel_state.py +0 -7
  7. sglang/srt/entrypoints/engine.py +18 -15
  8. sglang/srt/entrypoints/grpc_server.py +0 -1
  9. sglang/srt/entrypoints/http_server.py +75 -94
  10. sglang/srt/environ.py +16 -2
  11. sglang/srt/eplb/expert_distribution.py +30 -0
  12. sglang/srt/function_call/function_call_parser.py +2 -0
  13. sglang/srt/function_call/minimax_m2.py +367 -0
  14. sglang/srt/layers/activation.py +6 -0
  15. sglang/srt/layers/attention/flashattention_backend.py +12 -2
  16. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  17. sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
  18. sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
  19. sglang/srt/layers/attention/utils.py +78 -0
  20. sglang/srt/layers/communicator.py +1 -0
  21. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  22. sglang/srt/layers/layernorm.py +19 -4
  23. sglang/srt/layers/logits_processor.py +5 -0
  24. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  25. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  26. sglang/srt/layers/moe/ep_moe/layer.py +79 -272
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  29. sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
  30. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  31. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  32. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  33. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  34. sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
  35. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  36. sglang/srt/layers/moe/topk.py +4 -4
  37. sglang/srt/layers/moe/utils.py +3 -4
  38. sglang/srt/layers/quantization/__init__.py +3 -5
  39. sglang/srt/layers/quantization/awq.py +0 -3
  40. sglang/srt/layers/quantization/base_config.py +7 -0
  41. sglang/srt/layers/quantization/fp8.py +68 -63
  42. sglang/srt/layers/quantization/gguf.py +566 -0
  43. sglang/srt/layers/quantization/mxfp4.py +30 -38
  44. sglang/srt/layers/quantization/unquant.py +23 -45
  45. sglang/srt/layers/quantization/w4afp8.py +38 -2
  46. sglang/srt/layers/radix_attention.py +5 -2
  47. sglang/srt/layers/rotary_embedding.py +13 -1
  48. sglang/srt/layers/sampler.py +12 -1
  49. sglang/srt/managers/io_struct.py +3 -0
  50. sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
  51. sglang/srt/managers/scheduler.py +21 -15
  52. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  53. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  54. sglang/srt/managers/tokenizer_manager.py +11 -19
  55. sglang/srt/mem_cache/hicache_storage.py +7 -1
  56. sglang/srt/mem_cache/memory_pool.py +82 -0
  57. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  58. sglang/srt/model_executor/forward_batch_info.py +44 -3
  59. sglang/srt/model_executor/model_runner.py +1 -149
  60. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  61. sglang/srt/models/deepseek_v2.py +147 -44
  62. sglang/srt/models/glm4_moe.py +322 -354
  63. sglang/srt/models/glm4_moe_nextn.py +4 -14
  64. sglang/srt/models/glm4v_moe.py +29 -196
  65. sglang/srt/models/minimax_m2.py +922 -0
  66. sglang/srt/models/nvila.py +355 -0
  67. sglang/srt/models/nvila_lite.py +184 -0
  68. sglang/srt/models/qwen2.py +22 -1
  69. sglang/srt/models/qwen3.py +34 -4
  70. sglang/srt/models/qwen3_moe.py +2 -4
  71. sglang/srt/multimodal/processors/base_processor.py +1 -0
  72. sglang/srt/multimodal/processors/glm4v.py +1 -1
  73. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  74. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  75. sglang/srt/parser/reasoning_parser.py +28 -1
  76. sglang/srt/server_args.py +365 -186
  77. sglang/srt/single_batch_overlap.py +2 -7
  78. sglang/srt/utils/common.py +87 -42
  79. sglang/srt/utils/hf_transformers_utils.py +7 -3
  80. sglang/test/test_deterministic.py +235 -12
  81. sglang/test/test_deterministic_utils.py +2 -1
  82. sglang/version.py +1 -1
  83. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
  84. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
  85. sglang/srt/models/vila.py +0 -306
  86. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  87. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  88. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/function_call/minimax_m2.py (new file)
@@ -0,0 +1,367 @@
+import ast
+import html
+import json
+import logging
+import re
+from typing import Any, Dict, List, Tuple
+
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+    StreamingParseResult,
+    ToolCallItem,
+    _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+
+logger = logging.getLogger(__name__)
+
+
+def _safe_val(raw: str) -> Any:
+    raw = html.unescape(raw.strip())
+    try:
+        return json.loads(raw)
+    except Exception:
+        try:
+            return ast.literal_eval(raw)
+        except Exception:
+            return raw
+
+
+class MinimaxM2Detector(BaseFormatDetector):
+    """
+    Detector for MiniMax M2 models.
+    Assumes function call format:
+    <minimax:tool_call>
+    <invoke name="func1">
+    <parameter name="param1">value1</parameter>
+    <parameter name="param2">value2</parameter>
+    </invoke>
+    </minimax:tool_call>
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.tool_call_start_token: str = "<minimax:tool_call>"
+        self.tool_call_end_token: str = "</minimax:tool_call>"
+        self.tool_call_prefix: str = '<invoke name="'
+        self.tool_call_function_end_token: str = "</invoke>"
+        self.tool_call_regex = re.compile(
+            r"<minimax:tool_call>(.*?)</minimax:tool_call>|<minimax:tool_call>(.*?)$",
+            re.DOTALL,
+        )
+        self.tool_call_function_regex = re.compile(
+            r"<invoke name=\"(.*?)</invoke>|<invoke name=\"(.*)$", re.DOTALL
+        )
+        self.tool_call_parameter_regex = re.compile(
+            r"<parameter name=\"(.*?)</parameter>|<parameter name=\"(.*?)$", re.DOTALL
+        )
+        self._buf: str = ""
+
+        # Streaming state variables
+        self._current_function_name: str = ""
+        self._current_parameters: Dict[str, Any] = {}
+        self._streamed_parameters: Dict[str, str] = (
+            {}
+        )  # Track what parameter content we've streamed
+        self._in_tool_call: bool = False
+        self._function_name_sent: bool = False
+
+    def has_tool_call(self, text: str) -> bool:
+        return self.tool_call_start_token in text
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        normal, calls = self._extract(text, tools)
+        return StreamingParseResult(normal_text=normal, calls=calls)
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        self._buf += new_text
+        normal = ""
+        calls: List[ToolCallItem] = []
+
+        # Build tool indices for validation
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = self._get_tool_indices(tools)
+
+        while True:
+            # If we're not in a tool call and don't see a start token, return normal text
+            if not self._in_tool_call and self.tool_call_start_token not in self._buf:
+                normal += self._buf
+                self._buf = ""
+                break
+
+            # Look for tool call start
+            if not self._in_tool_call:
+                s = self._buf.find(self.tool_call_start_token)
+                if s == -1:
+                    normal += self._buf
+                    self._buf = ""
+                    break
+
+                normal += self._buf[:s]
+                self._buf = self._buf[s:]
+
+                self._in_tool_call = True
+                self._function_name_sent = False
+                self._current_function_name = ""
+                self._current_parameters = {}
+                self._streamed_parameters = {}
+
+                # Remove the start token
+                self._buf = self._buf[len(self.tool_call_start_token) :]
+                continue
+
+            # We're in a tool call, try to parse function name if not sent yet
+            if not self._function_name_sent:
+                # Look for function name pattern: <invoke name="...">
+                function_match = re.search(r"<invoke name=\"([^>]+)\">", self._buf)
+                if function_match:
+                    function_name = function_match.group(1).strip()
+
+                    # Validate function name
+                    if function_name in self._tool_indices:
+                        self._current_function_name = function_name
+                        self._function_name_sent = True
+
+                        # Initialize tool call tracking
+                        if self.current_tool_id == -1:
+                            self.current_tool_id = 0
+
+                        # Ensure tracking arrays are large enough
+                        while len(self.prev_tool_call_arr) <= self.current_tool_id:
+                            self.prev_tool_call_arr.append({})
+                        while len(self.streamed_args_for_tool) <= self.current_tool_id:
+                            self.streamed_args_for_tool.append("")
+
+                        # Store tool call info
+                        self.prev_tool_call_arr[self.current_tool_id] = {
+                            "name": function_name,
+                            "arguments": {},
+                        }
+
+                        # Send tool name with empty parameters
+                        calls.append(
+                            ToolCallItem(
+                                tool_index=self.current_tool_id,
+                                name=function_name,
+                                parameters="",
+                            )
+                        )
+
+                        # Remove the processed function declaration
+                        self._buf = self._buf[function_match.end() :]
+                        continue
+                    else:
+                        # Invalid function name, reset state
+                        logger.warning(f"Invalid function name: {function_name}")
+                        self._reset_streaming_state()
+                        normal += self._buf
+                        self._buf = ""
+                        break
+                else:
+                    # Function name not complete yet, wait for more text
+                    break
+
+            # Parse parameters incrementally
+            if self._function_name_sent:
+                # Process parameters and get any calls to emit
+                parameter_calls = self._parse_and_stream_parameters(self._buf)
+                calls.extend(parameter_calls)
+
+                # Check if tool call is complete
+                if self.tool_call_function_end_token in self._buf:
+                    end_pos = self._buf.find(self.tool_call_function_end_token)
+
+                    # Add closing brace to complete the JSON object
+                    current_streamed = self.streamed_args_for_tool[self.current_tool_id]
+                    if current_streamed:
+                        # Count opening and closing braces to check if JSON is complete
+                        open_braces = current_streamed.count("{")
+                        close_braces = current_streamed.count("}")
+                        if open_braces > close_braces:
+                            calls.append(
+                                ToolCallItem(
+                                    tool_index=self.current_tool_id,
+                                    name=None,
+                                    parameters="}",
+                                )
+                            )
+                            self.streamed_args_for_tool[self.current_tool_id] = (
+                                current_streamed + "}"
+                            )
+
+                    # Complete the tool call
+                    self._buf = self._buf[
+                        end_pos + len(self.tool_call_function_end_token) :
+                    ]
+                    self._reset_streaming_state(True)
+                    self.current_tool_id += 1
+                    continue
+                else:
+                    # Tool call not complete yet, wait for more text
+                    break
+
+        return StreamingParseResult(normal_text=normal, calls=calls)
+
+    def _parse_and_stream_parameters(self, text_to_parse: str) -> List[ToolCallItem]:
+        """
+        Parse complete parameter blocks from text and return any tool call items to emit.
+
+        This method:
+        1. Finds all complete <parameter> blocks
+        2. Parses them into a dictionary
+        3. Compares with current parameters and generates diff if needed
+        4. Updates internal state
+
+        Args:
+            text_to_parse: The text to search for parameter blocks
+
+        Returns:
+            List of ToolCallItem objects to emit (may be empty)
+        """
+        calls: List[ToolCallItem] = []
+
+        # Find all complete parameter patterns
+        param_matches = list(
+            re.finditer(
+                r"<parameter name=\"([^>]+)\">(.*?)</parameter>",
+                text_to_parse,
+                re.DOTALL,
+            )
+        )
+
+        # Build new parameters dictionary
+        new_params = {}
+        for match in param_matches:
+            param_name = match.group(1).strip()
+            param_value = match.group(2)
+            new_params[param_name] = _safe_val(param_value)
+
+        # Calculate parameter diff to stream with proper incremental JSON building
+        if new_params != self._current_parameters:
+            previous_args_json = self.streamed_args_for_tool[self.current_tool_id]
+
+            # Build incremental JSON properly
+            if not self._current_parameters:
+                # First parameter(s) - start JSON object but don't close it yet
+                items = []
+                for key, value in new_params.items():
+                    items.append(
+                        f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}"
+                    )
+                json_fragment = "{" + ", ".join(items)
+
+                calls.append(
+                    ToolCallItem(
+                        tool_index=self.current_tool_id,
+                        name=None,
+                        parameters=json_fragment,
+                    )
+                )
+                self.streamed_args_for_tool[self.current_tool_id] = json_fragment
+
+            else:
+                # Additional parameters - add them incrementally
+                new_keys = set(new_params.keys()) - set(self._current_parameters.keys())
+                if new_keys:
+                    # Build the continuation part (no closing brace yet)
+                    continuation_parts = []
+                    for key in new_keys:
+                        value = new_params[key]
+                        continuation_parts.append(
+                            f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}"
+                        )
+
+                    json_fragment = ", " + ", ".join(continuation_parts)
+
+                    calls.append(
+                        ToolCallItem(
+                            tool_index=self.current_tool_id,
+                            name=None,
+                            parameters=json_fragment,
+                        )
+                    )
+                    self.streamed_args_for_tool[self.current_tool_id] = (
+                        previous_args_json + json_fragment
+                    )
+
+            # Update current state
+            self._current_parameters = new_params
+            self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params
+
+        return calls
+
+    def _reset_streaming_state(self, still_in_tool_call: bool = False):
+        """Reset streaming state for the next tool call"""
+        self._in_tool_call = still_in_tool_call
+        self._function_name_sent = False
+        self._current_function_name = ""
+        self._current_parameters = {}
+        self._streamed_parameters = {}
+        self.current_tool_name_sent = False
+
+    def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]:
+        normal_parts: List[str] = []
+        calls: List[ToolCallItem] = []
+        cursor = 0
+        while True:
+            s = text.find(self.tool_call_start_token, cursor)
+            if s == -1:
+                normal_parts.append(text[cursor:])
+                break
+            normal_parts.append(text[cursor:s])
+            e = text.find(self.tool_call_end_token, s)
+            if e == -1:
+                normal_parts.append(text[s:])
+                break
+            block = text[s : e + len(self.tool_call_end_token)]
+            cursor = e + len(self.tool_call_end_token)
+            calls.extend(self._parse_block(block, tools))
+        return "".join(normal_parts), calls
+
+    def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]:
+        res: List[ToolCallItem] = []
+        for m in self.tool_call_function_regex.findall(block):
+            txt = m[0] if m[0] else m[1]
+            if '">' not in txt:
+                continue
+            idx = txt.index('">')
+            fname = txt[:idx].strip()
+            body = txt[idx + 2 :]
+            params: Dict[str, Any] = {}
+            for pm in self.tool_call_parameter_regex.findall(body):
+                ptxt = pm[0] if pm[0] else pm[1]
+                if '">' not in ptxt:
+                    continue
+                pidx = ptxt.index('">')
+                pname = ptxt[:pidx].strip()
+                pval = ptxt[pidx + 2 :].lstrip("\n").rstrip("\n")
+                params[pname] = _safe_val(pval)
+            raw = {"name": fname, "arguments": params}
+            try:
+                # TODO: fix idx in function call, the index for a function
+                # call will always be -1 in parse_base_json
+                res.extend(self.parse_base_json(raw, tools))
+            except Exception:
+                logger.warning("invalid tool call for %s dropped", fname)
+        return res
+
+    def supports_structural_tag(self) -> bool:
+        return False
+
+    def structure_info(self) -> _GetInfoFunc:
+        raise NotImplementedError
+
+    def build_ebnf(self, tools: List[Tool]):
+        return EBNFComposer.build_ebnf(
+            tools,
+            individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
+            individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
+            tool_call_separator="\\n",
+            function_format="xml",
+            call_rule_fmt='"<invoke name=\\"{name}\\">\\n" {arguments_rule} "\\n</invoke>"',
+            key_value_rule_fmt='"<parameter name=\\"{key}\\">\\n" {valrule} "\\n</parameter>"',
+            key_value_separator='"\\n"',
+        )
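
For illustration, here is a standalone sketch (not part of the package) of how the detector's regexes decompose the format described in the class docstring; the sample text and tool name are hypothetical:

import re

sample = (
    "Let me check the weather.\n"
    "<minimax:tool_call>\n"
    '<invoke name="get_weather">\n'
    '<parameter name="city">Paris</parameter>\n'
    '<parameter name="days">3</parameter>\n'
    "</invoke>\n"
    "</minimax:tool_call>"
)

# The same patterns MinimaxM2Detector compiles above
func_re = re.compile(r"<invoke name=\"(.*?)</invoke>|<invoke name=\"(.*)$", re.DOTALL)
param_re = re.compile(r"<parameter name=\"([^>]+)\">(.*?)</parameter>", re.DOTALL)

for m in func_re.findall(sample):
    txt = m[0] or m[1]
    fname, body = txt.split('">', 1)
    print(fname)                         # get_weather
    print(dict(param_re.findall(body)))  # {'city': 'Paris', 'days': '3'}

In the real _parse_block, each value is additionally passed through _safe_val, so "3" would be coerced to the integer 3 by json.loads before the arguments are serialized.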
sglang/srt/layers/activation.py
@@ -29,6 +29,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     cpu_has_amx_support,
     is_cpu,
@@ -59,6 +60,11 @@ logger = logging.getLogger(__name__)
 
 
 class SiluAndMul(CustomOp):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            self._forward_method = self.forward_native
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
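
For context, forward_native is the standard SwiGLU gating: the last dimension packs the gate and up halves, and the output is silu(gate) * up. A minimal sketch with made-up shapes:

import torch
import torch.nn.functional as F

x = torch.randn(4, 2 * 8)  # hypothetical: batch of 4, intermediate size 8 per half
d = x.shape[-1] // 2
out = F.silu(x[..., :d]) * x[..., d:]  # same computation as forward_native
assert out.shape == (4, 8)

The new __init__ simply pins _forward_method to this native path when the server is configured with rl_on_policy_target="fsdp", bypassing the fused op.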
sglang/srt/layers/attention/flashattention_backend.py
@@ -855,14 +855,24 @@ class FlashAttentionBackend(AttentionBackend):
             )
         else:
             # MHA for extend part of sequence without attending prefix kv cache
+            cu_seqlens_k = (
+                metadata.cu_seqlens_q
+                if not forward_batch.mha_one_shot
+                else metadata.cu_seqlens_k
+            )
+            max_seqlen_k = (
+                metadata.max_seq_len_q
+                if not forward_batch.mha_one_shot
+                else metadata.max_seq_len_k
+            )
             output = flash_attn_varlen_func(
                 q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                 v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                 cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k=metadata.cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
                 max_seqlen_q=metadata.max_seq_len_q,
-                max_seqlen_k=metadata.max_seq_len_q,
+                max_seqlen_k=max_seqlen_k,
                 softmax_scale=layer.scaling,
                 causal=True,
                 return_softmax_lse=forward_batch.mha_return_lse,
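
The fix matters because cu_seqlens_k must describe the K tensor actually passed to flash_attn_varlen_func. These "cu" tensors are exclusive prefix sums of per-request token counts; a small sketch with illustrative values:

import torch

extend_lens = torch.tensor([3, 5, 2])  # tokens being extended per request
cu_seqlens_q = torch.nn.functional.pad(extend_lens.cumsum(0), (1, 0))
# tensor([ 0,  3,  8, 10])
# In the default path K covers only the extend tokens, so cu_seqlens_k equals
# cu_seqlens_q; in the mha_one_shot path K spans the full sequences, so the
# metadata's cu_seqlens_k / max_seq_len_k must be used instead.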
sglang/srt/layers/attention/flashinfer_backend.py
@@ -230,7 +230,16 @@ class FlashInferAttnBackend(AttentionBackend):
 
         fmha_backend = "auto"
         if is_sm100_supported():
-            fmha_backend = "cutlass"
+            # Disable CUTLASS backend when piecewise cuda graph is enabled
+            # due to TMA descriptor initialization issues on B200
+            if model_runner.server_args.enable_piecewise_cuda_graph:
+                logger.warning(
+                    "CUTLASS backend is disabled when piecewise cuda graph is enabled "
+                    "due to TMA descriptor initialization issues on B200. "
+                    "Using auto backend instead for stability."
+                )
+            else:
+                fmha_backend = "cutlass"
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
             self.workspace_buffer, "NHD", backend=fmha_backend
         )
sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -82,6 +82,7 @@ class FlashInferMhaChunkKVRunner:
 
         # Buffers and wrappers
         self.qo_indptr = attn_backend.qo_indptr
+        self.kv_indptr = attn_backend.kv_indptr
         self.workspace_buffer = attn_backend.workspace_buffer
         self.fmha_backend = attn_backend.fmha_backend
 
@@ -132,9 +133,14 @@ class FlashInferMhaChunkKVRunner:
         )
         # ragged prefill
         if not disable_flashinfer_ragged:
+            kv_indptr = (
+                qo_indptr
+                if not forward_batch.mha_one_shot
+                else self.kv_indptr[: bs + 1]
+            )
             self.ragged_wrapper.begin_forward(
                 qo_indptr=qo_indptr,
-                kv_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
                 num_qo_heads=self.num_local_heads,
                 num_kv_heads=self.num_local_heads,
                 head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
@@ -156,7 +162,7 @@ class FlashInferMhaChunkKVRunner:
             chunk_idx = forward_batch.prefix_chunk_idx
             assert chunk_idx >= 0
             wrapper = self.chunk_ragged_wrappers[chunk_idx]
-            o1, s1 = wrapper.forward_return_lse(
+            o = wrapper.forward_return_lse(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                 v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
@@ -165,7 +171,12 @@ class FlashInferMhaChunkKVRunner:
                 logits_soft_cap=logits_soft_cap,
             )
         else:
-            o1, s1 = self.ragged_wrapper.forward_return_lse(
+            forward = (
+                self.ragged_wrapper.forward_return_lse
+                if forward_batch.mha_return_lse
+                else self.ragged_wrapper.forward
+            )
+            o = forward(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                 v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
@@ -173,8 +184,7 @@ class FlashInferMhaChunkKVRunner:
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
             )
-
-        return o1, s1
+        return o
 
 
 class FlashInferMLAAttnBackend(AttentionBackend):
@@ -512,15 +522,13 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
-        if (
-            forward_batch.attn_attend_prefix_cache is not None
-            and forward_batch.mha_return_lse
+        if forward_batch.attn_attend_prefix_cache is not None and any(
+            forward_batch.extend_prefix_lens_cpu
         ):  # MHA Chunk
             assert self.enable_chunk_kv
             assert q_rope is None
            assert k_rope is None
-            o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
-            return o1, s1
+            return self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
 
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
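
Background for the forward_return_lse changes: when a long prefix is attended chunk by chunk, the partial outputs are combined using their log-sum-exp (LSE) statistics, which is why the chunked path must return both the output and its LSE. A generic, numerically stable sketch of that merge (not sglang's exact helper):

import torch

def merge_attention_chunks(o1, lse1, o2, lse2):
    # o*: [tokens, heads, head_dim]; lse*: [tokens, heads]
    m = torch.maximum(lse1, lse2)  # subtract the max before exponentiating
    w1 = torch.exp(lse1 - m).unsqueeze(-1)
    w2 = torch.exp(lse2 - m).unsqueeze(-1)
    return (o1 * w1 + o2 * w2) / (w1 + w2)

The new one-shot path needs no such merge, so when mha_return_lse is false the plain forward is used and a bare o is returned.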
sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -423,14 +423,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )
 
-        # Record the true maximum sequence length for this capture batch so that
-        # the kernel launch path (which requires an int not a tensor) can reuse
-        # it safely during both capture and replay.
-        max_seq_len_val = int(seq_lens.max().item())
-
         metadata = TRTLLMMLADecodeMetadata(
             block_kv_indices,
-            max_seq_len_val,
+            self.max_context_len,
         )
         if forward_mode.is_draft_extend(include_v2=True):
             num_tokens_per_bs = num_tokens // bs
@@ -509,13 +504,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )
 
-        # Update stored max_seq_len so subsequent kernel calls use the correct value
-        # Prefer CPU tensor to avoid GPU synchronization when available.
-        if seq_lens_cpu is not None:
-            metadata.max_seq_len = int(seq_lens_cpu.max().item())
-        else:
-            metadata.max_seq_len = int(seq_lens.max().item())
-
     def get_cuda_graph_seq_len_fill_value(self) -> int:
         """Get the fill value for sequence lengths in CUDA graph."""
         return 1
sglang/srt/layers/attention/utils.py
@@ -1,3 +1,4 @@
+import torch
 import triton
 import triton.language as tl
 
@@ -101,3 +102,80 @@ def create_flashmla_kv_indices_triton(
         data // PAGED_SIZE,
         mask=mask_out,
     )
+
+
+@triton.jit
+def concat_and_cast_mha_k_kernel(
+    k_ptr,
+    k_nope_ptr,
+    k_rope_ptr,
+    head_cnt: tl.constexpr,
+    k_stride0: tl.constexpr,
+    k_stride1: tl.constexpr,
+    nope_stride0: tl.constexpr,
+    nope_stride1: tl.constexpr,
+    rope_stride0: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    head_range = tl.arange(0, head_cnt)
+
+    k_head_ptr = k_ptr + pid_loc * k_stride0 + head_range[:, None] * k_stride1
+
+    nope_offs = tl.arange(0, nope_dim)
+
+    src_nope_ptr = (
+        k_nope_ptr
+        + pid_loc * nope_stride0
+        + head_range[:, None] * nope_stride1
+        + nope_offs[None, :]
+    )
+    dst_nope_ptr = k_head_ptr + nope_offs[None, :]
+
+    src_nope = tl.load(src_nope_ptr)
+    tl.store(dst_nope_ptr, src_nope)
+
+    rope_offs = tl.arange(0, rope_dim)
+    src_rope_ptr = k_rope_ptr + pid_loc * rope_stride0 + rope_offs[None, :]
+    dst_rope_ptr = k_head_ptr + nope_dim + rope_offs[None, :]
+    src_rope = tl.load(src_rope_ptr)
+    tl.store(dst_rope_ptr, src_rope)
+
+
+def concat_and_cast_mha_k_triton(
+    k: torch.Tensor,
+    k_nope: torch.Tensor,
+    k_rope: torch.Tensor,
+):
+    # The source data type will be implicitly converted to the target data type.
+    assert (
+        len(k.shape) == 3 and len(k_nope.shape) == 3 and len(k_rope.shape) == 3
+    ), f"shape should be 3d, but got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[0] == k_nope.shape[0] and k.shape[0] == k_rope.shape[0]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[1] == k_nope.shape[1] and 1 == k_rope.shape[1]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[-1] == k_nope.shape[-1] + k_rope.shape[-1]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+
+    nope_dim = k_nope.shape[-1]
+    rope_dim = k_rope.shape[-1]
+    grid = (k.shape[0],)
+
+    concat_and_cast_mha_k_kernel[grid](
+        k,
+        k_nope,
+        k_rope,
+        k.shape[1],
+        k.stride(0),
+        k.stride(1),
+        k_nope.stride(0),
+        k_nope.stride(1),
+        k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+    )
sglang/srt/layers/communicator.py
@@ -337,6 +337,7 @@ class LayerCommunicator:
         static_conditions_met = (
             (not self.is_last_layer)
             and (self._context.tp_size > 1)
+            and not is_dp_attention_enabled()
             and get_global_server_args().enable_flashinfer_allreduce_fusion
             and _is_flashinfer_available
         )
sglang/srt/layers/deep_gemm_wrapper/compile_utils.py
@@ -26,7 +26,7 @@ _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "fal
 
 # Force redirect deep_gemm cache_dir
 os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
-    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
+    "SGLANG_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
 )
 
 # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
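
Note the rename from SGL_DG_CACHE_DIR to SGLANG_DG_CACHE_DIR: deployments that pin the DeepGEMM JIT cache location must update the variable. For example (illustrative path):

import os

# Must be set before sglang imports deep_gemm
os.environ["SGLANG_DG_CACHE_DIR"] = "/data/cache/deep_gemm"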