sglang 0.4.1.post7__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. sglang/bench_offline_throughput.py +17 -11
  2. sglang/bench_one_batch.py +14 -6
  3. sglang/bench_serving.py +47 -44
  4. sglang/lang/chat_template.py +31 -0
  5. sglang/srt/configs/load_config.py +1 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
  7. sglang/srt/entrypoints/engine.py +5 -2
  8. sglang/srt/entrypoints/http_server.py +24 -0
  9. sglang/srt/function_call_parser.py +494 -0
  10. sglang/srt/layers/activation.py +5 -5
  11. sglang/srt/layers/dp_attention.py +3 -1
  12. sglang/srt/layers/layernorm.py +5 -5
  13. sglang/srt/layers/linear.py +24 -9
  14. sglang/srt/layers/logits_processor.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/layer.py +20 -12
  16. sglang/srt/layers/moe/fused_moe_native.py +17 -3
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  18. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
  19. sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
  20. sglang/srt/layers/parameter.py +16 -7
  21. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  22. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  23. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  25. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  27. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  29. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/fp8.py +4 -1
  31. sglang/srt/layers/rotary_embedding.py +6 -1
  32. sglang/srt/layers/sampler.py +28 -8
  33. sglang/srt/layers/torchao_utils.py +12 -6
  34. sglang/srt/managers/detokenizer_manager.py +1 -0
  35. sglang/srt/managers/io_struct.py +36 -5
  36. sglang/srt/managers/schedule_batch.py +31 -25
  37. sglang/srt/managers/scheduler.py +61 -35
  38. sglang/srt/managers/tokenizer_manager.py +4 -0
  39. sglang/srt/model_executor/cuda_graph_runner.py +23 -25
  40. sglang/srt/model_executor/forward_batch_info.py +5 -7
  41. sglang/srt/model_executor/model_runner.py +7 -4
  42. sglang/srt/model_loader/loader.py +75 -0
  43. sglang/srt/model_loader/weight_utils.py +91 -5
  44. sglang/srt/models/commandr.py +14 -2
  45. sglang/srt/models/dbrx.py +9 -1
  46. sglang/srt/models/deepseek_v2.py +3 -3
  47. sglang/srt/models/gemma2.py +9 -1
  48. sglang/srt/models/grok.py +1 -0
  49. sglang/srt/models/minicpm3.py +3 -3
  50. sglang/srt/models/torch_native_llama.py +17 -4
  51. sglang/srt/openai_api/adapter.py +139 -37
  52. sglang/srt/openai_api/protocol.py +5 -4
  53. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  54. sglang/srt/sampling/sampling_batch_info.py +4 -14
  55. sglang/srt/server.py +2 -2
  56. sglang/srt/server_args.py +20 -1
  57. sglang/srt/speculative/eagle_utils.py +37 -15
  58. sglang/srt/speculative/eagle_worker.py +11 -13
  59. sglang/srt/utils.py +62 -65
  60. sglang/test/test_programs.py +1 -0
  61. sglang/test/test_utils.py +81 -22
  62. sglang/version.py +1 -1
  63. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/METADATA +7 -7
  64. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/RECORD +67 -56
  65. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  66. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  67. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/function_call_parser.py
@@ -0,0 +1,494 @@
+ import json
+ import re
+ from abc import ABC, abstractmethod
+ from json import JSONDecodeError, JSONDecoder
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import partial_json_parser
+ from partial_json_parser.core.options import Allow
+ from pydantic import BaseModel, Field
+
+ TOOLS_TAG_LIST = [
+     "<|plugin|>",
+     "<function=",
+     "<tool_call>",
+     "<|python_tag|>",
+     "[TOOL_CALLS]",
+ ]
+
+
+ class Function(BaseModel):
+     """Function Tool Template."""
+
+     description: Optional[str] = Field(default=None, examples=[None])
+     name: Optional[str] = None
+     parameters: Optional[object] = None
+
+
+ class ToolCallItem(BaseModel):
+     """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
+
+     tool_index: int
+     name: Optional[str] = None
+     parameters: str  # JSON string
+
+
+ def _find_common_prefix(s1: str, s2: str) -> str:
+     prefix = ""
+     min_length = min(len(s1), len(s2))
+     for i in range(0, min_length):
+         if s1[i] == s2[i]:
+             prefix += s1[i]
+         else:
+             break
+     return prefix
+
+
+ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
+     try:
+         return (partial_json_parser.loads(input_str, flags), len(input_str))
+     except JSONDecodeError as e:
+         if "Extra data" in e.msg:
+             dec = JSONDecoder()
+             return dec.raw_decode(input_str)
+         raise
+
+
+ def _is_complete_json(input_str: str) -> bool:
+     try:
+         json.loads(input_str)
+         return True
+     except JSONDecodeError:
+         return False
+
+
+ class StreamingParseResult:
+     """Result of streaming incremental parsing."""
+
+     def __init__(
+         self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None
+     ):
+         self.normal_text = normal_text
+         self.calls = calls or []
+
+
+ class BaseFormatDetector:
+     """Base class providing two sets of interfaces: one-time and streaming incremental."""
+
+     def __init__(self):
+         # initialize properties used for state when parsing tool calls in
+         # streaming mode
+         self._buffer = ""
+         self.prev_tool_call_arr: List[Dict] = []
+         self.current_tool_id: int = -1
+         self.current_tool_name_sent: bool = False
+         self.streamed_args_for_tool: List[str] = (
+             []
+         )  # map what has been streamed for each tool so far to a list
+         self.bot_token = ""
+         self.eot_token = ""
+
+     def parse_base_json(self, action: Dict, tools: List[Function]):
+         name, parameters = action["name"], json.dumps(
+             action.get("parameters", action.get("arguments", {})),
+             ensure_ascii=False,
+         )
+         tool_index = [tool.function.name for tool in tools].index(name)
+         tool_call_item = ToolCallItem(
+             tool_index=tool_index, name=name, parameters=parameters
+         )
+         calls = [tool_call_item]
+         return calls
+
+     def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+         """
+         Parses the text in one go. Returns success=True if the format matches, otherwise False.
+         Note that leftover_text here represents "content that this parser will not consume further".
+         """
+         action = json.loads(text)
+         return self.parse_base_json(action, tools)
+
+     def parse_streaming_increment(
+         self, new_text: str, tools: List[Function]
+     ) -> StreamingParseResult:
+         """
+         Streaming incremental parsing, referencing the logic of Llama32Detector.
+         We partially parse JSON within <tool_call>...</tool_call>, and handle
+         incremental argument output.
+         """
+         # Append new text to buffer
+         self._buffer += new_text
+         current_text = self._buffer
+         if not (self.bot_token in current_text or current_text.startswith("{")):
+             self._buffer = ""
+             if self.eot_token in new_text:
+                 new_text = new_text.replace(self.eot_token, "")
+             return StreamingParseResult(normal_text=new_text)
+
+         # bit mask flags for partial JSON parsing. If the name hasn't been
+         # sent yet, don't allow sending
+         # an incomplete string since OpenAI only ever (as far as I have
+         # seen) allows sending the entire tool/ function name at once.
+         flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
+         try:
+             tool_call_arr = []
+             is_complete = []
+             try:
+                 # depending on the prompt format the Llama model may or may not
+                 # prefix the output with the <|python_tag|> token
+                 start_idx = (
+                     len(self.bot_token)
+                     if current_text.startswith(self.bot_token)
+                     else 0
+                 )
+                 while start_idx < len(current_text):
+                     (obj, end_idx) = _partial_json_loads(
+                         current_text[start_idx:], flags
+                     )
+                     is_complete.append(
+                         _is_complete_json(current_text[start_idx : start_idx + end_idx])
+                     )
+                     start_idx += end_idx + len("; ")
+                     # depending on the prompt Llama can use
+                     # either arguments or parameters
+                     if "parameters" in obj:
+                         assert (
+                             "arguments" not in obj
+                         ), "model generated both parameters and arguments"
+                         obj["arguments"] = obj["parameters"]
+                     tool_call_arr.append(obj)
+
+             except partial_json_parser.core.exceptions.MalformedJSON:
+                 # not enough tokens to parse into JSON yet
+                 return StreamingParseResult()
+
+             # select as the current tool call the one we're on the state at
+             current_tool_call: Dict = (
+                 tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
+             )
+
+             # case -- if no tokens have been streamed for the tool, e.g.
+             # only the array brackets, stream nothing
+             if len(tool_call_arr) == 0:
+                 return StreamingParseResult()
+
+             # case: we are starting a new tool in the array
+             # -> array has > 0 length AND length has moved past cursor
+             elif (
+                 len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
+             ):
+
+                 # if we're moving on to a new call, first make sure we
+                 # haven't missed anything in the previous one that was
+                 # auto-generated due to JSON completions, but wasn't
+                 # streamed to the client yet.
+                 if self.current_tool_id >= 0:
+                     cur_arguments = current_tool_call.get("arguments")
+                     if cur_arguments:
+                         cur_args_json = json.dumps(cur_arguments)
+                         sent = len(self.streamed_args_for_tool[self.current_tool_id])
+                         argument_diff = cur_args_json[sent:]
+
+                         res = StreamingParseResult(
+                             normal_text=None,
+                             calls=[
+                                 ToolCallItem(
+                                     tool_index=self.current_tool_id,
+                                     name="",
+                                     parameters=argument_diff,
+                                 )
+                             ],
+                         )
+                         self.streamed_args_for_tool[
+                             self.current_tool_id
+                         ] += argument_diff
+                     else:
+                         res = StreamingParseResult()
+                 else:
+                     res = StreamingParseResult()
+                 # re-set stuff pertaining to progress in the current tool
+                 self.current_tool_id = len(tool_call_arr) - 1
+                 self.current_tool_name_sent = False
+                 self.streamed_args_for_tool.append("")
+                 print("starting on new tool %d", self.current_tool_id)
+                 return res
+
+             # if the current tool name hasn't been sent, send if available
+             # - otherwise send nothing
+             elif not self.current_tool_name_sent:
+                 function_name = current_tool_call.get("name")
+                 if function_name:
+                     res = StreamingParseResult(
+                         normal_text=None,
+                         calls=[
+                             ToolCallItem(
+                                 tool_index=self.current_tool_id,
+                                 name=function_name,
+                                 parameters="",
+                             )
+                         ],
+                     )
+                     self.current_tool_name_sent = True
+                 else:
+                     res = StreamingParseResult()
+
+             # now we know we're on the same tool call and we're streaming
+             # arguments
+             else:
+                 cur_arguments = current_tool_call.get("arguments")
+                 res = StreamingParseResult()
+
+                 if cur_arguments:
+                     sent = len(self.streamed_args_for_tool[self.current_tool_id])
+                     cur_args_json = json.dumps(cur_arguments)
+                     prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
+                         "arguments"
+                     )
+
+                     argument_diff = None
+                     if is_complete[self.current_tool_id]:
+                         argument_diff = cur_args_json[sent:]
+                         self._buffer = ""
+                         self.prev_tool_call_arr[self.current_tool_id].clear()
+                         self.current_tool_name_sent: bool = False
+                         self.streamed_args_for_tool[self.current_tool_id] = ""
+
+                     elif prev_arguments:
+                         prev_args_json = json.dumps(prev_arguments)
+                         if cur_args_json != prev_args_json:
+
+                             prefix = _find_common_prefix(prev_args_json, cur_args_json)
+                             argument_diff = prefix[sent:]
+
+                     if argument_diff is not None:
+                         res = StreamingParseResult(
+                             calls=[
+                                 ToolCallItem(
+                                     tool_index=self.current_tool_id,
+                                     name="",
+                                     parameters=argument_diff,
+                                 )
+                             ],
+                         )
+                         if not is_complete[self.current_tool_id]:
+                             self.streamed_args_for_tool[
+                                 self.current_tool_id
+                             ] += argument_diff
+
+             self.prev_tool_call_arr = tool_call_arr
+             return res
+
+         except Exception as e:
+             print(e)
+             # Skipping chunk as a result of tool streaming extraction error
+             return StreamingParseResult()
+
+
+ class Qwen25Detector(BaseFormatDetector):
+     """
+     Detector for Qwen 2.5 models.
+     Assumes function call format:
+     <tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
+     """
+
+     def __init__(self):
+         """
+         Initializes the detector with necessary state variables.
+         """
+         super().__init__()
+         self.bot_token = "<tool_call>"
+         self.eot_token = "</tool_call>"
+
+     def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+         """
+         One-time parsing: Detects and parses tool calls in the provided text.
+
+         :param text: The complete text to parse.
+         :param tools: List of available tools.
+         :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+         """
+         if "<tool_call>" not in text:
+             return []
+         pattern = r"<tool_call>(.*?)</tool_call>"
+         match_result_list = re.findall(pattern, text, re.DOTALL)
+         calls = []
+         for match_result in match_result_list:
+             match_result = json.loads(match_result)
+             calls.extend(self.parse_base_json(match_result, tools))
+         return calls
+
+
+ class MistralDetector(BaseFormatDetector):
+     """
+     Detector for Mistral models.
+     Assumes function call format:
+     <|action_start|><|plugin|>{"name":"xxx", "arguments":{...}}<|action_end|>
+     """
+
+     def __init__(self):
+         """
+         Initializes the detector with necessary state variables.
+         """
+         super().__init__()
+         self.bot_token = "[TOOL_CALLS] ["
+         self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
+
+     def _clean_text(self, text: str) -> str:
+         """
+         clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
+         for example,
+             text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
+             return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
+         The key pattern is [TOOL_CALLS] [...]
+         """
+         find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
+         if len(find_results) > 0:
+             return find_results[0]
+         else:
+             return ""
+
+     def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+         """
+         One-time parsing: Detects and parses tool calls in the provided text.
+
+         :param text: The complete text to parse.
+         :param tools: List of available tools.
+         :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+         """
+         text = self._clean_text(text)
+         tool_content = text.replace("[TOOL_CALLS]", "").strip()
+         raw_tool_calls = self.tool_call_regex.findall(tool_content)
+         calls = []
+         if len(raw_tool_calls) > 0:
+             raw_tool_call = raw_tool_calls[0]
+             function_call_arr = json.loads(raw_tool_call)
+             for match_result in function_call_arr:
+                 calls.extend(self.parse_base_json(match_result, tools))
+         return calls
+
+
+ class Llama32Detector(BaseFormatDetector):
+     """
+     Detector for Llama 3.2 models.
+     Assumes function call format:
+     <|python_tag|>{"name":"xxx", "arguments":{...}}
+     Does not require a closing tag "</python_tag|>",
+     relies on json.loads(...) success to determine if JSON is complete.
+     """
+
+     def __init__(self):
+         """
+         Initializes the detector with necessary state variables.
+         """
+         super().__init__()
+         self.bot_token = "<|python_tag|>"
+
+     def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+         """
+         One-time parsing: Detects and parses tool calls in the provided text.
+
+         :param text: The complete text to parse.
+         :param tools: List of available tools.
+         :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+         """
+
+         if "<|python_tag|>" not in text:
+             return []
+         _, action = text.split("<|python_tag|>")
+         action = json.loads(action)
+         return self.parse_base_json(action, tools)
+
+
+ class MultiFormatParser:
+     def __init__(self, detectors: List[BaseFormatDetector]):
+         """
+         :param detectors: A series of available Detector instances passed in
+         """
+         self.detectors = detectors
+
+     def parse_once(self, text: str, tools: List[Function]):
+         """
+         One-time parsing: Loop through detectors until there are no new matches or text is exhausted
+         Return: (final_text, all_calls)
+         - final_text: The remaining text after parsing that was not consumed by any Detector (can be treated as normal text)
+         - all_calls: All calls parsed by the Detectors
+         """
+         final_calls = []
+         final_normal_text = text
+         for detector in self.detectors:
+             tool_call_list = detector.detect_and_parse(text, tools)
+             if len(tool_call_list) > 0:  # parsed successfully
+                 final_calls = tool_call_list
+                 break
+
+         # leftover_text is the normal text not consumed by any Detector
+         return final_normal_text, final_calls
+
+     def parse_streaming_increment(self, new_text: str, tools: List[Function]):
+         """
+         Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
+         and merge their produced normal_text/calls to return.
+         (The logic here can be "priority-based" or "parallel parsing" based on your needs)
+         """
+         final_normal_text = ""
+         final_calls = []
+
+         for detector in self.detectors:
+             sp_result = detector.parse_streaming_increment(new_text, tools)
+             # Merge normal_text and calls
+             # If one sp_result contains result call, this should be a successful parse
+             # If one sp_result only contains normal_text, this can either be a successful
+             # parse or it is not using the desired parsing tool.
+             if sp_result.normal_text:
+                 final_normal_text = sp_result.normal_text
+             if sp_result.calls:
+                 final_calls.extend(sp_result.calls)
+                 final_normal_text = sp_result.normal_text
+                 break
+
+         return final_normal_text, final_calls
+
+
+ class FunctionCallParser:
+     """
+     In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment
+     and returns the resulting normal_text and calls to the upper layer (or SSE).
+     """
+
+     ToolCallParserEnum: Dict[str, BaseFormatDetector] = {
+         "llama3": Llama32Detector,
+         "qwen25": Qwen25Detector,
+         "mistral": MistralDetector,
+     }
+
+     def __init__(self, tools: List[Function], tool_call_parser: str = None):
+         detectors = []
+         if tool_call_parser:
+             detector_class = self.ToolCallParserEnum.get(tool_call_parser)
+             if detector_class:
+                 detectors.append(detector_class())
+             else:
+                 raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
+         else:
+             raise ValueError("Tool Call Parser Not Given!")
+
+         self.multi_format_parser = MultiFormatParser(detectors)
+         self.tools = tools
+
+     def parse_non_stream(self, full_text: str):
+         """
+         Non-streaming call: one-time parsing
+         """
+         full_normal_text, calls = self.multi_format_parser.parse_once(
+             full_text, self.tools
+         )
+         return full_normal_text, calls
+
+     def parse_stream_chunk(self, chunk_text: str):
+         """
+         Streaming call: incremental parsing
+         """
+         normal_text, calls = self.multi_format_parser.parse_streaming_increment(
+             chunk_text, self.tools
+         )
+         return normal_text, calls
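Aside (editor's illustration, not part of the diff): the new sglang/srt/function_call_parser.py above is the entry point the OpenAI adapter uses for tool-call extraction. A minimal sketch of driving it directly, assuming each tool entry exposes a `.function.name` attribute as `parse_base_json` expects; the `SimpleNamespace` wrapper and the example strings are hypothetical:

# Hedged usage sketch for the new parser (class and method names from the diff;
# inputs and the Tool wrapper are invented for illustration).
from types import SimpleNamespace

from sglang.srt.function_call_parser import Function, FunctionCallParser

weather = Function(name="get_current_weather", parameters={"type": "object"})
tools = [SimpleNamespace(function=weather)]  # parse_base_json reads tool.function.name

parser = FunctionCallParser(tools=tools, tool_call_parser="qwen25")

# Non-streaming: every <tool_call>...</tool_call> block is extracted in one pass.
text = '<tool_call>{"name": "get_current_weather", "arguments": {"city": "Boston"}}</tool_call>'
normal_text, calls = parser.parse_non_stream(text)
assert calls[0].name == "get_current_weather"

# Streaming: feed decoded chunks; the tool name is emitted first, then
# argument diffs as the partial JSON becomes parseable.
for chunk in ['<tool_call>{"name": "get_current_weather",', ' "arguments": {"city": "Boston"}}</tool_call>']:
    normal_text, calls = parser.parse_stream_chunk(chunk)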
sglang/srt/layers/activation.py
@@ -20,10 +20,10 @@ import torch
  import torch.nn as nn
  import torch.nn.functional as F

- from sglang.srt.utils import is_flashinfer_available
+ from sglang.srt.utils import is_cuda_available

- if is_flashinfer_available():
-     from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+ if is_cuda_available():
+     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

  from vllm.model_executor.custom_op import CustomOp

@@ -149,8 +149,8 @@ def get_act_fn(
      return act_fn


- if not is_flashinfer_available():
+ if not is_cuda_available():
      logger.info(
-         "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
+         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
      )
      from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
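Both hunks apply one pattern: import the fused ops from sgl_kernel when CUDA is available, otherwise log and fall back to vllm's implementations. A condensed restatement of the idiom (rearranged into a single if/else for illustration; the diff itself keeps the fallback at the bottom of the module):

# Illustrative restatement of the guarded-import idiom (not verbatim from the diff).
from sglang.srt.utils import is_cuda_available

if is_cuda_available():
    # Fused CUDA kernels shipped in the sgl-kernel package.
    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
else:
    # Non-NV platforms: fall back to vllm's PyTorch-native activation layers.
    from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul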
sglang/srt/layers/dp_attention.py
@@ -22,6 +22,8 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
  def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
      global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE

+     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
+
      _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
          enable_dp_attention, tp_rank, tp_size, dp_size
      )
@@ -35,7 +37,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
          ],
          tp_rank,
          torch.distributed.get_backend(tp_group.device_group),
-         False,
+         SYNC_TOKEN_IDS_ACROSS_TP,
          False,
          False,
          False,
sglang/srt/layers/layernorm.py
@@ -19,10 +19,10 @@ from typing import Optional, Tuple, Union
  import torch
  import torch.nn as nn

- from sglang.srt.utils import is_flashinfer_available
+ from sglang.srt.utils import is_cuda_available

- if is_flashinfer_available():
-     from flashinfer.norm import (
+ if is_cuda_available():
+     from sgl_kernel import (
          fused_add_rmsnorm,
          gemma_fused_add_rmsnorm,
          gemma_rmsnorm,
@@ -121,8 +121,8 @@ class GemmaRMSNorm(CustomOp):
          return out


- if not is_flashinfer_available():
+ if not is_cuda_available():
      logger.info(
-         "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
+         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
      )
      from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
sglang/srt/layers/linear.py
@@ -329,12 +329,14 @@ class ColumnParallelLinear(LinearBase):
          prefix: str = "",
          tp_rank: Optional[int] = None,
          tp_size: Optional[int] = None,
+         use_presharded_weights: bool = False,
      ):
          super().__init__(
              input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
          )

          self.gather_output = gather_output
+         self.use_presharded_weights = use_presharded_weights

          # Divide the weight matrix along the last dimension.
          if tp_rank is None:
@@ -402,7 +404,8 @@ class ColumnParallelLinear(LinearBase):
          if output_dim is not None and not use_bitsandbytes_4bit:
              shard_size = param_data.shape[output_dim]
              start_idx = self.tp_rank * shard_size
-             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+             if not self.use_presharded_weights:
+                 loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

          # Special case for loading scales off disk, which often do not
          # have a shape (such as in the case of AutoFP8).
@@ -418,7 +421,11 @@ class ColumnParallelLinear(LinearBase):
          if len(loaded_weight.shape) == 0:
              assert loaded_weight.numel() == 1
              loaded_weight = loaded_weight.reshape(1)
-         param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank)
+         param.load_column_parallel_weight(
+             loaded_weight,
+             tp_rank=self.tp_rank,
+             use_presharded_weights=self.use_presharded_weights,
+         )

      def forward(self, input_):
          bias = self.bias if not self.skip_bias_add else None
@@ -499,7 +506,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
              prefix=prefix,
              tp_rank=tp_rank,
              tp_size=tp_size,
+             use_presharded_weights=use_presharded_weights,
          )
+         self.prefix = prefix

      def weight_loader(
          self,
@@ -743,6 +752,7 @@ class QKVParallelLinear(ColumnParallelLinear):
          prefix: str = "",
          tp_rank: Optional[int] = None,
          tp_size: Optional[int] = None,
+         load_presharded_attn: bool = False,
      ):
          self.hidden_size = hidden_size
          self.head_size = head_size
@@ -772,6 +782,7 @@ class QKVParallelLinear(ColumnParallelLinear):
              self.num_kv_heads * self.head_size * tp_size,  # k_proj
              self.num_kv_heads * self.head_size * tp_size,  # v_proj
          ]
+         self.use_presharded_weights = load_presharded_attn

          super().__init__(
              input_size=input_size,
@@ -784,6 +795,7 @@ class QKVParallelLinear(ColumnParallelLinear):
              prefix=prefix,
              tp_rank=tp_rank,
              tp_size=tp_size,
+             use_presharded_weights=self.use_presharded_weights,
          )

      def _get_shard_offset_mapping(self, loaded_shard_id: str):
@@ -842,9 +854,10 @@ class QKVParallelLinear(ColumnParallelLinear):
                  shard_size=shard_size, shard_offset=shard_offset
              )

-             loaded_weight_shard = loaded_weight.narrow(
-                 param.output_dim, shard_offset, shard_size
-             )
+             if not self.use_presharded_weights:
+                 loaded_weight_shard = loaded_weight.narrow(
+                     param.output_dim, shard_offset, shard_size
+                 )
              self.weight_loader_v2(param, loaded_weight_shard, shard_id)

      def weight_loader_v2(
@@ -882,6 +895,7 @@ class QKVParallelLinear(ColumnParallelLinear):
              shard_offset=shard_offset,
              shard_size=shard_size,
              tp_rank=self.tp_rank,
+             use_presharded_weights=self.use_presharded_weights,
          )

      def weight_loader(
@@ -987,9 +1001,10 @@ class QKVParallelLinear(ColumnParallelLinear):
                  param, orig_qkv_offsets, shard_id
              )

-             loaded_weight_shard = loaded_weight.narrow(
-                 output_dim, shard_offset, shard_size
-             )
+             if not self.use_presharded_weights:
+                 loaded_weight_shard = loaded_weight.narrow(
+                     output_dim, shard_offset, shard_size
+                 )
              self.weight_loader(param, loaded_weight_shard, shard_id)
              return

@@ -1049,7 +1064,7 @@ class QKVParallelLinear(ColumnParallelLinear):

          # bitsandbytes loads the weights of the specific portion
          # no need to narrow here
-         if not use_bitsandbytes_4bit:
+         if not use_bitsandbytes_4bit and not self.use_presharded_weights:
              loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

          # Special case for for AQLM codebooks.
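The thread running through these linear.py hunks is a single `use_presharded_weights` flag: checkpoints that already store one shard per tensor-parallel rank must not be narrowed a second time. A toy illustration of the distinction, with hypothetical shapes (not from the diff):

# Toy sketch of why narrow() is skipped for presharded checkpoints.
import torch

tp_rank, tp_size = 1, 4
full_weight = torch.randn(1024, 512)          # full checkpoint tensor
shard_size = full_weight.shape[0] // tp_size  # 256 rows per rank

# Regular checkpoint: every rank reads the full tensor and slices its rows.
regular_shard = full_weight.narrow(0, tp_rank * shard_size, shard_size)

# Presharded checkpoint: the file already holds only this rank's 256 rows,
# so narrowing again would slice a shard out of a shard (wrong or out of range).
presharded = torch.randn(shard_size, 512)
use_presharded_weights = True
loaded = presharded if use_presharded_weights else full_weight.narrow(
    0, tp_rank * shard_size, shard_size
)
assert loaded.shape == (shard_size, 512)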
@@ -296,7 +296,7 @@ def fused_softcap_kernel(
      n_elements,
      BLOCK_SIZE: tl.constexpr,
  ):
-     pid = tl.program_id(0)
+     pid = tl.program_id(0).to(tl.int64)
      block_start = pid * BLOCK_SIZE
      offsets = block_start + tl.arange(0, BLOCK_SIZE)
      mask = offsets < n_elements
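The cast to tl.int64 matters because `block_start = pid * BLOCK_SIZE` is computed in the program id's dtype; with a 32-bit pid, the offset wraps once the tensor exceeds 2**31 elements. Plain-Python arithmetic showing the wraparound the cast prevents (illustrative numbers, not from the diff):

# Simulate 32-bit offset arithmetic for a large grid (illustration only).
BLOCK_SIZE = 1024
pid = 2_100_000                      # plausible program id for a >2B-element tensor

block_start = pid * BLOCK_SIZE       # 2_150_400_000 > 2**31 - 1
wrapped = (block_start + 2**31) % 2**32 - 2**31  # what int32 math would yield
print(block_start, wrapped)          # 2150400000 -2144567296 (negative offset)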