sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +56 -12
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/model_config.py +5 -5
- sglang/srt/distributed/parallel_state.py +0 -7
- sglang/srt/entrypoints/engine.py +18 -15
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +75 -94
- sglang/srt/environ.py +16 -2
- sglang/srt/eplb/expert_distribution.py +30 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/flashattention_backend.py +12 -2
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
- sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +1 -0
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +19 -4
- sglang/srt/layers/logits_processor.py +5 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -272
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +4 -4
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/quantization/__init__.py +3 -5
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +13 -1
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/managers/io_struct.py +3 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
- sglang/srt/managers/scheduler.py +21 -15
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/tokenizer_manager.py +11 -19
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +82 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +44 -3
- sglang/srt/model_executor/model_runner.py +1 -149
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/models/deepseek_v2.py +147 -44
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v_moe.py +29 -196
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +22 -1
- sglang/srt/models/qwen3.py +34 -4
- sglang/srt/models/qwen3_moe.py +2 -4
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/server_args.py +365 -186
- sglang/srt/single_batch_overlap.py +2 -7
- sglang/srt/utils/common.py +87 -42
- sglang/srt/utils/hf_transformers_utils.py +7 -3
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
- {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
- sglang/srt/models/vila.py +0 -306
- {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/function_call/minimax_m2.py
ADDED
@@ -0,0 +1,367 @@
+import ast
+import html
+import json
+import logging
+import re
+from typing import Any, Dict, List, Tuple
+
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+    StreamingParseResult,
+    ToolCallItem,
+    _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+
+logger = logging.getLogger(__name__)
+
+
+def _safe_val(raw: str) -> Any:
+    raw = html.unescape(raw.strip())
+    try:
+        return json.loads(raw)
+    except Exception:
+        try:
+            return ast.literal_eval(raw)
+        except Exception:
+            return raw
+
+
+class MinimaxM2Detector(BaseFormatDetector):
+    """
+    Detector for MiniMax M2 models.
+    Assumes function call format:
+    <minimax:tool_call>
+    <invoke name="func1">
+    <parameter name="param1">value1</parameter>
+    <parameter name="param2">value2</parameter>
+    </invoke>
+    </minimax:tool_call>
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.tool_call_start_token: str = "<minimax:tool_call>"
+        self.tool_call_end_token: str = "</minimax:tool_call>"
+        self.tool_call_prefix: str = '<invoke name="'
+        self.tool_call_function_end_token: str = "</invoke>"
+        self.tool_call_regex = re.compile(
+            r"<minimax:tool_call>(.*?)</minimax:tool_call>|<minimax:tool_call>(.*?)$",
+            re.DOTALL,
+        )
+        self.tool_call_function_regex = re.compile(
+            r"<invoke name=\"(.*?)</invoke>|<invoke name=\"(.*)$", re.DOTALL
+        )
+        self.tool_call_parameter_regex = re.compile(
+            r"<parameter name=\"(.*?)</parameter>|<parameter name=\"(.*?)$", re.DOTALL
+        )
+        self._buf: str = ""
+
+        # Streaming state variables
+        self._current_function_name: str = ""
+        self._current_parameters: Dict[str, Any] = {}
+        self._streamed_parameters: Dict[str, str] = (
+            {}
+        )  # Track what parameter content we've streamed
+        self._in_tool_call: bool = False
+        self._function_name_sent: bool = False
+
+    def has_tool_call(self, text: str) -> bool:
+        return self.tool_call_start_token in text
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        normal, calls = self._extract(text, tools)
+        return StreamingParseResult(normal_text=normal, calls=calls)
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        self._buf += new_text
+        normal = ""
+        calls: List[ToolCallItem] = []
+
+        # Build tool indices for validation
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = self._get_tool_indices(tools)
+
+        while True:
+            # If we're not in a tool call and don't see a start token, return normal text
+            if not self._in_tool_call and self.tool_call_start_token not in self._buf:
+                normal += self._buf
+                self._buf = ""
+                break
+
+            # Look for tool call start
+            if not self._in_tool_call:
+                s = self._buf.find(self.tool_call_start_token)
+                if s == -1:
+                    normal += self._buf
+                    self._buf = ""
+                    break
+
+                normal += self._buf[:s]
+                self._buf = self._buf[s:]
+
+                self._in_tool_call = True
+                self._function_name_sent = False
+                self._current_function_name = ""
+                self._current_parameters = {}
+                self._streamed_parameters = {}
+
+                # Remove the start token
+                self._buf = self._buf[len(self.tool_call_start_token) :]
+                continue
+
+            # We're in a tool call, try to parse function name if not sent yet
+            if not self._function_name_sent:
+                # Look for function name pattern: <invoke name=name>
+                function_match = re.search(r"<invoke name=\"([^>]+)\">", self._buf)
+                if function_match:
+                    function_name = function_match.group(1).strip()
+
+                    # Validate function name
+                    if function_name in self._tool_indices:
+                        self._current_function_name = function_name
+                        self._function_name_sent = True
+
+                        # Initialize tool call tracking
+                        if self.current_tool_id == -1:
+                            self.current_tool_id = 0
+
+                        # Ensure tracking arrays are large enough
+                        while len(self.prev_tool_call_arr) <= self.current_tool_id:
+                            self.prev_tool_call_arr.append({})
+                        while len(self.streamed_args_for_tool) <= self.current_tool_id:
+                            self.streamed_args_for_tool.append("")
+
+                        # Store tool call info
+                        self.prev_tool_call_arr[self.current_tool_id] = {
+                            "name": function_name,
+                            "arguments": {},
+                        }
+
+                        # Send tool name with empty parameters
+                        calls.append(
+                            ToolCallItem(
+                                tool_index=self.current_tool_id,
+                                name=function_name,
+                                parameters="",
+                            )
+                        )
+
+                        # Remove the processed function declaration
+                        self._buf = self._buf[function_match.end() :]
+                        continue
+                    else:
+                        # Invalid function name, reset state
+                        logger.warning(f"Invalid function name: {function_name}")
+                        self._reset_streaming_state()
+                        normal += self._buf
+                        self._buf = ""
+                        break
+                else:
+                    # Function name not complete yet, wait for more text
+                    break
+
+            # Parse parameters incrementally
+            if self._function_name_sent:
+                # Process parameters and get any calls to emit
+                parameter_calls = self._parse_and_stream_parameters(self._buf)
+                calls.extend(parameter_calls)
+
+                # Check if tool call is complete
+                if self.tool_call_function_end_token in self._buf:
+                    end_pos = self._buf.find(self.tool_call_function_end_token)
+
+                    # Add closing brace to complete the JSON object
+                    current_streamed = self.streamed_args_for_tool[self.current_tool_id]
+                    if current_streamed:
+                        # Count opening and closing braces to check if JSON is complete
+                        open_braces = current_streamed.count("{")
+                        close_braces = current_streamed.count("}")
+                        if open_braces > close_braces:
+                            calls.append(
+                                ToolCallItem(
+                                    tool_index=self.current_tool_id,
+                                    name=None,
+                                    parameters="}",
+                                )
+                            )
+                            self.streamed_args_for_tool[self.current_tool_id] = (
+                                current_streamed + "}"
+                            )
+
+                    # Complete the tool call
+                    self._buf = self._buf[
+                        end_pos + len(self.tool_call_function_end_token) :
+                    ]
+                    self._reset_streaming_state(True)
+                    self.current_tool_id += 1
+                    continue
+                else:
+                    # Tool call not complete yet, wait for more text
+                    break
+
+        return StreamingParseResult(normal_text=normal, calls=calls)
+
+    def _parse_and_stream_parameters(self, text_to_parse: str) -> List[ToolCallItem]:
+        """
+        Parse complete parameter blocks from text and return any tool call items to emit.
+
+        This method:
+        1. Finds all complete <parameter> blocks
+        2. Parses them into a dictionary
+        3. Compares with current parameters and generates diff if needed
+        4. Updates internal state
+
+        Args:
+            text_to_parse: The text to search for parameter blocks
+
+        Returns:
+            List of ToolCallItem objects to emit (may be empty)
+        """
+        calls: List[ToolCallItem] = []
+
+        # Find all complete parameter patterns
+        param_matches = list(
+            re.finditer(
+                r"<parameter name=\"([^>]+)\">(.*?)</parameter>",
+                text_to_parse,
+                re.DOTALL,
+            )
+        )
+
+        # Build new parameters dictionary
+        new_params = {}
+        for match in param_matches:
+            param_name = match.group(1).strip()
+            param_value = match.group(2)
+            new_params[param_name] = _safe_val(param_value)
+
+        # Calculate parameter diff to stream with proper incremental JSON building
+        if new_params != self._current_parameters:
+            previous_args_json = self.streamed_args_for_tool[self.current_tool_id]
+
+            # Build incremental JSON properly
+            if not self._current_parameters:
+                # First parameter(s) - start JSON object but don't close it yet
+                items = []
+                for key, value in new_params.items():
+                    items.append(
+                        f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}"
+                    )
+                json_fragment = "{" + ", ".join(items)
+
+                calls.append(
+                    ToolCallItem(
+                        tool_index=self.current_tool_id,
+                        name=None,
+                        parameters=json_fragment,
+                    )
+                )
+                self.streamed_args_for_tool[self.current_tool_id] = json_fragment
+
+            else:
+                # Additional parameters - add them incrementally
+                new_keys = set(new_params.keys()) - set(self._current_parameters.keys())
+                if new_keys:
+                    # Build the continuation part (no closing brace yet)
+                    continuation_parts = []
+                    for key in new_keys:
+                        value = new_params[key]
+                        continuation_parts.append(
+                            f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}"
+                        )
+
+                    json_fragment = ", " + ", ".join(continuation_parts)
+
+                    calls.append(
+                        ToolCallItem(
+                            tool_index=self.current_tool_id,
+                            name=None,
+                            parameters=json_fragment,
+                        )
+                    )
+                    self.streamed_args_for_tool[self.current_tool_id] = (
+                        previous_args_json + json_fragment
+                    )
+
+        # Update current state
+        self._current_parameters = new_params
+        self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params
+
+        return calls
+
+    def _reset_streaming_state(self, still_in_tool_call: bool = False):
+        """Reset streaming state for the next tool call"""
+        self._in_tool_call = still_in_tool_call
+        self._function_name_sent = False
+        self._current_function_name = ""
+        self._current_parameters = {}
+        self._streamed_parameters = {}
+        self.current_tool_name_sent = False
+
+    def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]:
+        normal_parts: List[str] = []
+        calls: List[ToolCallItem] = []
+        cursor = 0
+        while True:
+            s = text.find(self.tool_call_start_token, cursor)
+            if s == -1:
+                normal_parts.append(text[cursor:])
+                break
+            normal_parts.append(text[cursor:s])
+            e = text.find(self.tool_call_end_token, s)
+            if e == -1:
+                normal_parts.append(text[s:])
+                break
+            block = text[s : e + len(self.tool_call_end_token)]
+            cursor = e + len(self.tool_call_end_token)
+            calls.extend(self._parse_block(block, tools))
+        return "".join(normal_parts), calls
+
+    def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]:
+        res: List[ToolCallItem] = []
+        for m in self.tool_call_function_regex.findall(block):
+            txt = m[0] if m[0] else m[1]
+            if '">' not in txt:
+                continue
+            idx = txt.index('">')
+            fname = txt[:idx].strip()
+            body = txt[idx + 2 :]
+            params: Dict[str, Any] = {}
+            for pm in self.tool_call_parameter_regex.findall(body):
+                ptxt = pm[0] if pm[0] else pm[1]
+                if '">' not in ptxt:
+                    continue
+                pidx = ptxt.index('">')
+                pname = ptxt[:pidx].strip()
+                pval = ptxt[pidx + 2 :].lstrip("\n").rstrip("\n")
+                params[pname] = _safe_val(pval)
+            raw = {"name": fname, "arguments": params}
+            try:
+                # TODO: fix idx in function call, the index for a function
+                # call will always be -1 in parse_base_json
+                res.extend(self.parse_base_json(raw, tools))
+            except Exception:
+                logger.warning("invalid tool call for %s dropped", fname)
+        return res
+
+    def supports_structural_tag(self) -> bool:
+        return False
+
+    def structure_info(self) -> _GetInfoFunc:
+        raise NotImplementedError
+
+    def build_ebnf(self, tools: List[Tool]):
+        return EBNFComposer.build_ebnf(
+            tools,
+            individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
+            individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
+            tool_call_separator="\\n",
+            function_format="xml",
+            call_rule_fmt='"<invoke name=\\"{name}\\">\\n" {arguments_rule} "\\n</invoke>"',
+            key_value_rule_fmt='"<parameter name=\\"{key}\\">\\n" {valrule} "\\n</parameter>"',
+            key_value_separator='"\\n"',
+        )
sglang/srt/layers/activation.py
CHANGED
@@ -29,6 +29,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     cpu_has_amx_support,
     is_cpu,
@@ -59,6 +60,11 @@ logger = logging.getLogger(__name__)
 
 
 class SiluAndMul(CustomOp):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            self._forward_method = self.forward_native
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
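The new __init__ only switches SiluAndMul to its native path; what forward_native computes is unchanged. A minimal restatement of that computation (shapes are illustrative):

import torch
import torch.nn.functional as F

# forward_native splits the last dimension in half: the first half is passed
# through SiLU and gates the second half (the usual gated-MLP / SwiGLU layout).
x = torch.randn(4, 2 * 11)  # e.g. a gate_up projection output
d = x.shape[-1] // 2
out = F.silu(x[..., :d]) * x[..., d:]
assert out.shape == (4, 11)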
sglang/srt/layers/attention/flashattention_backend.py
CHANGED
@@ -855,14 +855,24 @@ class FlashAttentionBackend(AttentionBackend):
                 )
             else:
                 # MHA for extend part of sequence without attending prefix kv cache
+                cu_seqlens_k = (
+                    metadata.cu_seqlens_q
+                    if not forward_batch.mha_one_shot
+                    else metadata.cu_seqlens_k
+                )
+                max_seqlen_k = (
+                    metadata.max_seq_len_q
+                    if not forward_batch.mha_one_shot
+                    else metadata.max_seq_len_k
+                )
                 output = flash_attn_varlen_func(
                     q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                     v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
-                    cu_seqlens_k=metadata.cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
                     max_seqlen_q=metadata.max_seq_len_q,
-                    max_seqlen_k=metadata.max_seq_len_q,
+                    max_seqlen_k=max_seqlen_k,
                     softmax_scale=layer.scaling,
                     causal=True,
                     return_softmax_lse=forward_batch.mha_return_lse,
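The selection above only decides which cumulative-length tensor flash_attn_varlen_func sees. As an illustrative sketch (not the backend's actual metadata code), cu_seqlens_* are exclusive prefix sums over per-request lengths in the packed varlen batch; in the one-shot case the keys span prefix plus extend tokens, otherwise key lengths equal query lengths:

import torch

extend_lens = torch.tensor([3, 5, 2], dtype=torch.int32)  # new tokens per request
prefix_lens = torch.tensor([7, 0, 4], dtype=torch.int32)  # cached prefix per request

def cu(lens):
    # [3, 5, 2] -> [0, 3, 8, 10]
    zero = torch.zeros(1, dtype=torch.int32)
    return torch.cat([zero, torch.cumsum(lens, dim=0, dtype=torch.int32)])

cu_seqlens_q = cu(extend_lens)                # queries cover only the extend part
cu_seqlens_k = cu(extend_lens + prefix_lens)  # one-shot: keys cover prefix + extend
print(cu_seqlens_q.tolist())  # [0, 3, 8, 10]
print(cu_seqlens_k.tolist())  # [0, 10, 15, 21]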
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -230,7 +230,16 @@ class FlashInferAttnBackend(AttentionBackend):
 
         fmha_backend = "auto"
         if is_sm100_supported():
-            fmha_backend = "cutlass"
+            # Disable CUTLASS backend when piecewise cuda graph is enabled
+            # due to TMA descriptor initialization issues on B200
+            if model_runner.server_args.enable_piecewise_cuda_graph:
+                logger.warning(
+                    "CUTLASS backend is disabled when piecewise cuda graph is enabled "
+                    "due to TMA descriptor initialization issues on B200. "
+                    "Using auto backend instead for stability."
+                )
+            else:
+                fmha_backend = "cutlass"
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
             self.workspace_buffer, "NHD", backend=fmha_backend
         )
sglang/srt/layers/attention/flashinfer_mla_backend.py
CHANGED
@@ -82,6 +82,7 @@ class FlashInferMhaChunkKVRunner:
 
         # Buffers and wrappers
         self.qo_indptr = attn_backend.qo_indptr
+        self.kv_indptr = attn_backend.kv_indptr
         self.workspace_buffer = attn_backend.workspace_buffer
         self.fmha_backend = attn_backend.fmha_backend
 
@@ -132,9 +133,14 @@ class FlashInferMhaChunkKVRunner:
         )
         # ragged prefill
         if not disable_flashinfer_ragged:
+            kv_indptr = (
+                qo_indptr
+                if not forward_batch.mha_one_shot
+                else self.kv_indptr[: bs + 1]
+            )
             self.ragged_wrapper.begin_forward(
                 qo_indptr=qo_indptr,
-                kv_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
                 num_qo_heads=self.num_local_heads,
                 num_kv_heads=self.num_local_heads,
                 head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
@@ -156,7 +162,7 @@ class FlashInferMhaChunkKVRunner:
             chunk_idx = forward_batch.prefix_chunk_idx
             assert chunk_idx >= 0
             wrapper = self.chunk_ragged_wrappers[chunk_idx]
-            o1, s1 = wrapper.forward_return_lse(
+            o = wrapper.forward_return_lse(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                 v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
@@ -165,7 +171,12 @@ class FlashInferMhaChunkKVRunner:
                 logits_soft_cap=logits_soft_cap,
             )
         else:
-            o1, s1 = self.ragged_wrapper.forward_return_lse(
+            forward = (
+                self.ragged_wrapper.forward_return_lse
+                if forward_batch.mha_return_lse
+                else self.ragged_wrapper.forward
+            )
+            o = forward(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                 v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
@@ -173,8 +184,7 @@ class FlashInferMhaChunkKVRunner:
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
             )
-
-        return o1, s1
+        return o
 
 
 class FlashInferMLAAttnBackend(AttentionBackend):
@@ -512,15 +522,13 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
-        if (
-            forward_batch.attn_attend_prefix_cache
-            and forward_batch.mha_return_lse
+        if forward_batch.attn_attend_prefix_cache is not None and any(
+            forward_batch.extend_prefix_lens_cpu
         ):  # MHA Chunk
             assert self.enable_chunk_kv
             assert q_rope is None
             assert k_rope is None
-            o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
-            return o1, s1
+            return self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
 
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
sglang/srt/layers/attention/trtllm_mla_backend.py
CHANGED
@@ -423,14 +423,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )
 
-        # Record the true maximum sequence length for this capture batch so that
-        # the kernel launch path (which requires an int not a tensor) can reuse
-        # it safely during both capture and replay.
-        max_seq_len_val = int(seq_lens.max().item())
-
         metadata = TRTLLMMLADecodeMetadata(
             block_kv_indices,
-            max_seq_len_val,
+            self.max_context_len,
         )
         if forward_mode.is_draft_extend(include_v2=True):
             num_tokens_per_bs = num_tokens // bs
@@ -509,13 +504,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )
 
-        # Update stored max_seq_len so subsequent kernel calls use the correct value
-        # Prefer CPU tensor to avoid GPU synchronization when available.
-        if seq_lens_cpu is not None:
-            metadata.max_seq_len = int(seq_lens_cpu.max().item())
-        else:
-            metadata.max_seq_len = int(seq_lens.max().item())
-
     def get_cuda_graph_seq_len_fill_value(self) -> int:
         """Get the fill value for sequence lengths in CUDA graph."""
         return 1
sglang/srt/layers/attention/utils.py
CHANGED
@@ -1,3 +1,4 @@
+import torch
 import triton
 import triton.language as tl
 
@@ -101,3 +102,80 @@ def create_flashmla_kv_indices_triton(
         data // PAGED_SIZE,
         mask=mask_out,
     )
+
+
+@triton.jit
+def concat_and_cast_mha_k_kernel(
+    k_ptr,
+    k_nope_ptr,
+    k_rope_ptr,
+    head_cnt: tl.constexpr,
+    k_stride0: tl.constexpr,
+    k_stride1: tl.constexpr,
+    nope_stride0: tl.constexpr,
+    nope_stride1: tl.constexpr,
+    rope_stride0: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    head_range = tl.arange(0, head_cnt)
+
+    k_head_ptr = k_ptr + pid_loc * k_stride0 + head_range[:, None] * k_stride1
+
+    nope_offs = tl.arange(0, nope_dim)
+
+    src_nope_ptr = (
+        k_nope_ptr
+        + pid_loc * nope_stride0
+        + head_range[:, None] * nope_stride1
+        + nope_offs[None, :]
+    )
+    dst_nope_ptr = k_head_ptr + nope_offs[None, :]
+
+    src_nope = tl.load(src_nope_ptr)
+    tl.store(dst_nope_ptr, src_nope)
+
+    rope_offs = tl.arange(0, rope_dim)
+    src_rope_ptr = k_rope_ptr + pid_loc * rope_stride0 + rope_offs[None, :]
+    dst_rope_ptr = k_head_ptr + nope_dim + rope_offs[None, :]
+    src_rope = tl.load(src_rope_ptr)
+    tl.store(dst_rope_ptr, src_rope)
+
+
+def concat_and_cast_mha_k_triton(
+    k: torch.Tensor,
+    k_nope: torch.Tensor,
+    k_rope: torch.Tensor,
+):
+    # The source data type will be implicitly converted to the target data type.
+    assert (
+        len(k.shape) == 3 and len(k_nope.shape) == 3 and len(k_rope.shape) == 3
+    ), f"shape should be 3d, but got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[0] == k_nope.shape[0] and k.shape[0] == k_rope.shape[0]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[1] == k_nope.shape[1] and 1 == k_rope.shape[1]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[-1] == k_nope.shape[-1] + k_rope.shape[-1]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+
+    nope_dim = k_nope.shape[-1]
+    rope_dim = k_rope.shape[-1]
+    grid = (k.shape[0],)
+
+    concat_and_cast_mha_k_kernel[grid](
+        k,
+        k_nope,
+        k_rope,
+        k.shape[1],
+        k.stride(0),
+        k.stride(1),
+        k_nope.stride(0),
+        k_nope.stride(1),
+        k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+    )
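A hedged usage sketch for the new helper, with shapes chosen to satisfy its asserts. A CUDA device is assumed, the nope/rope dimensions here are illustrative, and k_rope's single head is broadcast across all k heads by the kernel:

import torch

from sglang.srt.layers.attention.utils import concat_and_cast_mha_k_triton

tokens, heads, nope_dim, rope_dim = 16, 8, 128, 64
k_nope = torch.randn(tokens, heads, nope_dim, device="cuda")
k_rope = torch.randn(tokens, 1, rope_dim, device="cuda")
# Destination dtype may differ; the kernel casts implicitly on store.
k = torch.empty(tokens, heads, nope_dim + rope_dim, device="cuda", dtype=torch.bfloat16)

concat_and_cast_mha_k_triton(k, k_nope, k_rope)
assert torch.allclose(k[..., :nope_dim].float(), k_nope, atol=1e-2, rtol=1e-2)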
sglang/srt/layers/communicator.py
CHANGED
@@ -337,6 +337,7 @@ class LayerCommunicator:
         static_conditions_met = (
             (not self.is_last_layer)
             and (self._context.tp_size > 1)
+            and not is_dp_attention_enabled()
             and get_global_server_args().enable_flashinfer_allreduce_fusion
             and _is_flashinfer_available
         )
sglang/srt/layers/deep_gemm_wrapper/compile_utils.py
CHANGED
@@ -26,7 +26,7 @@ _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "fal
 
 # Force redirect deep_gemm cache_dir
 os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
-    "
+    "SGLANG_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
 )
 
 # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f