sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (141)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +41 -27
  4. sglang/bench_one_batch.py +60 -4
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +83 -71
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +46 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/load_config.py +1 -0
  13. sglang/srt/configs/model_config.py +1 -0
  14. sglang/srt/constrained/base_grammar_backend.py +21 -0
  15. sglang/srt/constrained/xgrammar_backend.py +8 -4
  16. sglang/srt/conversation.py +14 -1
  17. sglang/srt/distributed/__init__.py +3 -3
  18. sglang/srt/distributed/communication_op.py +2 -1
  19. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
  21. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  22. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  23. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  24. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  25. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  26. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  27. sglang/srt/distributed/parallel_state.py +1 -1
  28. sglang/srt/distributed/utils.py +2 -1
  29. sglang/srt/entrypoints/engine.py +452 -0
  30. sglang/srt/entrypoints/http_server.py +603 -0
  31. sglang/srt/function_call_parser.py +494 -0
  32. sglang/srt/layers/activation.py +8 -8
  33. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  34. sglang/srt/layers/attention/triton_backend.py +4 -6
  35. sglang/srt/layers/attention/vision.py +204 -0
  36. sglang/srt/layers/dp_attention.py +71 -0
  37. sglang/srt/layers/layernorm.py +5 -5
  38. sglang/srt/layers/linear.py +65 -14
  39. sglang/srt/layers/logits_processor.py +49 -64
  40. sglang/srt/layers/moe/ep_moe/layer.py +24 -16
  41. sglang/srt/layers/moe/fused_moe_native.py +84 -1
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
  45. sglang/srt/layers/parameter.py +18 -8
  46. sglang/srt/layers/quantization/__init__.py +20 -23
  47. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  49. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  51. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  56. sglang/srt/layers/quantization/fp8.py +10 -4
  57. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  58. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  59. sglang/srt/layers/radix_attention.py +2 -2
  60. sglang/srt/layers/rotary_embedding.py +1184 -31
  61. sglang/srt/layers/sampler.py +64 -6
  62. sglang/srt/layers/torchao_utils.py +12 -6
  63. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  64. sglang/srt/lora/lora.py +1 -9
  65. sglang/srt/managers/configure_logging.py +3 -0
  66. sglang/srt/managers/data_parallel_controller.py +79 -72
  67. sglang/srt/managers/detokenizer_manager.py +24 -6
  68. sglang/srt/managers/image_processor.py +158 -2
  69. sglang/srt/managers/io_struct.py +57 -3
  70. sglang/srt/managers/schedule_batch.py +78 -45
  71. sglang/srt/managers/schedule_policy.py +26 -12
  72. sglang/srt/managers/scheduler.py +326 -201
  73. sglang/srt/managers/session_controller.py +1 -0
  74. sglang/srt/managers/tokenizer_manager.py +210 -121
  75. sglang/srt/managers/tp_worker.py +6 -4
  76. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  77. sglang/srt/managers/utils.py +44 -0
  78. sglang/srt/mem_cache/memory_pool.py +10 -32
  79. sglang/srt/metrics/collector.py +15 -6
  80. sglang/srt/model_executor/cuda_graph_runner.py +26 -30
  81. sglang/srt/model_executor/forward_batch_info.py +5 -7
  82. sglang/srt/model_executor/model_runner.py +44 -19
  83. sglang/srt/model_loader/loader.py +83 -6
  84. sglang/srt/model_loader/weight_utils.py +145 -6
  85. sglang/srt/models/baichuan.py +6 -6
  86. sglang/srt/models/chatglm.py +2 -2
  87. sglang/srt/models/commandr.py +17 -5
  88. sglang/srt/models/dbrx.py +13 -5
  89. sglang/srt/models/deepseek.py +3 -3
  90. sglang/srt/models/deepseek_v2.py +11 -11
  91. sglang/srt/models/exaone.py +2 -2
  92. sglang/srt/models/gemma.py +2 -2
  93. sglang/srt/models/gemma2.py +15 -25
  94. sglang/srt/models/gpt2.py +3 -5
  95. sglang/srt/models/gpt_bigcode.py +1 -1
  96. sglang/srt/models/granite.py +2 -2
  97. sglang/srt/models/grok.py +4 -3
  98. sglang/srt/models/internlm2.py +2 -2
  99. sglang/srt/models/llama.py +7 -5
  100. sglang/srt/models/minicpm.py +2 -2
  101. sglang/srt/models/minicpm3.py +9 -9
  102. sglang/srt/models/minicpmv.py +1238 -0
  103. sglang/srt/models/mixtral.py +3 -3
  104. sglang/srt/models/mixtral_quant.py +3 -3
  105. sglang/srt/models/mllama.py +2 -2
  106. sglang/srt/models/olmo.py +3 -3
  107. sglang/srt/models/olmo2.py +4 -4
  108. sglang/srt/models/olmoe.py +7 -13
  109. sglang/srt/models/phi3_small.py +2 -2
  110. sglang/srt/models/qwen.py +2 -2
  111. sglang/srt/models/qwen2.py +41 -4
  112. sglang/srt/models/qwen2_moe.py +3 -3
  113. sglang/srt/models/qwen2_vl.py +22 -122
  114. sglang/srt/models/stablelm.py +2 -2
  115. sglang/srt/models/torch_native_llama.py +20 -7
  116. sglang/srt/models/xverse.py +6 -6
  117. sglang/srt/models/xverse_moe.py +6 -6
  118. sglang/srt/openai_api/adapter.py +139 -37
  119. sglang/srt/openai_api/protocol.py +7 -4
  120. sglang/srt/sampling/custom_logit_processor.py +38 -0
  121. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  122. sglang/srt/sampling/sampling_batch_info.py +143 -18
  123. sglang/srt/sampling/sampling_params.py +3 -1
  124. sglang/srt/server.py +4 -1090
  125. sglang/srt/server_args.py +77 -15
  126. sglang/srt/speculative/eagle_utils.py +37 -15
  127. sglang/srt/speculative/eagle_worker.py +11 -13
  128. sglang/srt/utils.py +164 -129
  129. sglang/test/runners.py +8 -13
  130. sglang/test/test_programs.py +2 -1
  131. sglang/test/test_utils.py +83 -22
  132. sglang/utils.py +12 -2
  133. sglang/version.py +1 -1
  134. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
  135. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
  136. sglang/launch_server_llavavid.py +0 -25
  137. sglang/srt/constrained/__init__.py +0 -16
  138. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  139. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  140. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  141. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/function_call_parser.py (new file)
@@ -0,0 +1,494 @@
+import json
+import re
+from abc import ABC, abstractmethod
+from json import JSONDecodeError, JSONDecoder
+from typing import Any, Dict, List, Optional, Tuple
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+from pydantic import BaseModel, Field
+
+TOOLS_TAG_LIST = [
+    "<|plugin|>",
+    "<function=",
+    "<tool_call>",
+    "<|python_tag|>",
+    "[TOOL_CALLS]",
+]
+
+
+class Function(BaseModel):
+    """Function Tool Template."""
+
+    description: Optional[str] = Field(default=None, examples=[None])
+    name: Optional[str] = None
+    parameters: Optional[object] = None
+
+
+class ToolCallItem(BaseModel):
+    """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
+
+    tool_index: int
+    name: Optional[str] = None
+    parameters: str  # JSON string
+
+
+def _find_common_prefix(s1: str, s2: str) -> str:
+    prefix = ""
+    min_length = min(len(s1), len(s2))
+    for i in range(0, min_length):
+        if s1[i] == s2[i]:
+            prefix += s1[i]
+        else:
+            break
+    return prefix
+
+
+def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
+    try:
+        return (partial_json_parser.loads(input_str, flags), len(input_str))
+    except JSONDecodeError as e:
+        if "Extra data" in e.msg:
+            dec = JSONDecoder()
+            return dec.raw_decode(input_str)
+        raise
+
+
+def _is_complete_json(input_str: str) -> bool:
+    try:
+        json.loads(input_str)
+        return True
+    except JSONDecodeError:
+        return False
+
+
+class StreamingParseResult:
+    """Result of streaming incremental parsing."""
+
+    def __init__(
+        self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None
+    ):
+        self.normal_text = normal_text
+        self.calls = calls or []
+
+
+class BaseFormatDetector:
+    """Base class providing two sets of interfaces: one-time and streaming incremental."""
+
+    def __init__(self):
+        # initialize properties used for state when parsing tool calls in
+        self._buffer = ""
+        # streaming mode
+        self.prev_tool_call_arr: List[Dict] = []
+        self.current_tool_id: int = -1
+        self.current_tool_name_sent: bool = False
+        self.streamed_args_for_tool: List[str] = (
+            []
+        )  # map what has been streamed for each tool so far to a list
+        self.bot_token = ""
+        self.eot_token = ""
+
+    def parse_base_json(self, action: Dict, tools: List[Function]):
+        name, parameters = action["name"], json.dumps(
+            action.get("parameters", action.get("arguments", {})),
+            ensure_ascii=False,
+        )
+        tool_index = [tool.function.name for tool in tools].index(name)
+        tool_call_item = ToolCallItem(
+            tool_index=tool_index, name=name, parameters=parameters
+        )
+        calls = [tool_call_item]
+        return calls
+
+    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+        """
+        Parses the text in one go. Returns success=True if the format matches, otherwise False.
+        Note that leftover_text here represents "content that this parser will not consume further".
+        """
+        action = json.loads(text)
+        return self.parse_base_json(action, tools)
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Function]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing, referencing the logic of Llama32Detector.
+        We partially parse JSON within <tool_call>...</tool_call>, and handle
+        incremental argument output.
+        """
+        # Append new text to buffer
+        self._buffer += new_text
+        current_text = self._buffer
+        if not (self.bot_token in current_text or current_text.startswith("{")):
+            self._buffer = ""
+            if self.eot_token in new_text:
+                new_text = new_text.replace(self.eot_token, "")
+            return StreamingParseResult(normal_text=new_text)
+
+        # bit mask flags for partial JSON parsing. If the name hasn't been
+        # sent yet, don't allow sending
+        # an incomplete string since OpenAI only ever (as far as I have
+        # seen) allows sending the entire tool/ function name at once.
+        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
+        try:
+            tool_call_arr = []
+            is_complete = []
+            try:
+                # depending on the prompt format the Llama model may or may not
+                # prefix the output with the <|python_tag|> token
+                start_idx = (
+                    len(self.bot_token)
+                    if current_text.startswith(self.bot_token)
+                    else 0
+                )
+                while start_idx < len(current_text):
+                    (obj, end_idx) = _partial_json_loads(
+                        current_text[start_idx:], flags
+                    )
+                    is_complete.append(
+                        _is_complete_json(current_text[start_idx : start_idx + end_idx])
+                    )
+                    start_idx += end_idx + len("; ")
+                    # depending on the prompt Llama can use
+                    # either arguments or parameters
+                    if "parameters" in obj:
+                        assert (
+                            "arguments" not in obj
+                        ), "model generated both parameters and arguments"
+                        obj["arguments"] = obj["parameters"]
+                    tool_call_arr.append(obj)
+
+            except partial_json_parser.core.exceptions.MalformedJSON:
+                # not enough tokens to parse into JSON yet
+                return StreamingParseResult()
+
+            # select as the current tool call the one we're on the state at
+            current_tool_call: Dict = (
+                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
+            )
+
+            # case -- if no tokens have been streamed for the tool, e.g.
+            # only the array brackets, stream nothing
+            if len(tool_call_arr) == 0:
+                return StreamingParseResult()
+
+            # case: we are starting a new tool in the array
+            # -> array has > 0 length AND length has moved past cursor
+            elif (
+                len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
+            ):
+
+                # if we're moving on to a new call, first make sure we
+                # haven't missed anything in the previous one that was
+                # auto-generated due to JSON completions, but wasn't
+                # streamed to the client yet.
+                if self.current_tool_id >= 0:
+                    cur_arguments = current_tool_call.get("arguments")
+                    if cur_arguments:
+                        cur_args_json = json.dumps(cur_arguments)
+                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
+                        argument_diff = cur_args_json[sent:]
+
+                        res = StreamingParseResult(
+                            normal_text=None,
+                            calls=[
+                                ToolCallItem(
+                                    tool_index=self.current_tool_id,
+                                    name="",
+                                    parameters=argument_diff,
+                                )
+                            ],
+                        )
+                        self.streamed_args_for_tool[
+                            self.current_tool_id
+                        ] += argument_diff
+                    else:
+                        res = StreamingParseResult()
+                else:
+                    res = StreamingParseResult()
+                # re-set stuff pertaining to progress in the current tool
+                self.current_tool_id = len(tool_call_arr) - 1
+                self.current_tool_name_sent = False
+                self.streamed_args_for_tool.append("")
+                print("starting on new tool %d", self.current_tool_id)
+                return res
+
+            # if the current tool name hasn't been sent, send if available
+            # - otherwise send nothing
+            elif not self.current_tool_name_sent:
+                function_name = current_tool_call.get("name")
+                if function_name:
+                    res = StreamingParseResult(
+                        normal_text=None,
+                        calls=[
+                            ToolCallItem(
+                                tool_index=self.current_tool_id,
+                                name=function_name,
+                                parameters="",
+                            )
+                        ],
+                    )
+                    self.current_tool_name_sent = True
+                else:
+                    res = StreamingParseResult()
+
+            # now we know we're on the same tool call and we're streaming
+            # arguments
+            else:
+                cur_arguments = current_tool_call.get("arguments")
+                res = StreamingParseResult()
+
+                if cur_arguments:
+                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
+                    cur_args_json = json.dumps(cur_arguments)
+                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
+                        "arguments"
+                    )
+
+                    argument_diff = None
+                    if is_complete[self.current_tool_id]:
+                        argument_diff = cur_args_json[sent:]
+                        self._buffer = ""
+                        self.prev_tool_call_arr[self.current_tool_id].clear()
+                        self.current_tool_name_sent: bool = False
+                        self.streamed_args_for_tool[self.current_tool_id] = ""
+
+                    elif prev_arguments:
+                        prev_args_json = json.dumps(prev_arguments)
+                        if cur_args_json != prev_args_json:
+
+                            prefix = _find_common_prefix(prev_args_json, cur_args_json)
+                            argument_diff = prefix[sent:]
+
+                    if argument_diff is not None:
+                        res = StreamingParseResult(
+                            calls=[
+                                ToolCallItem(
+                                    tool_index=self.current_tool_id,
+                                    name="",
+                                    parameters=argument_diff,
+                                )
+                            ],
+                        )
+                        if not is_complete[self.current_tool_id]:
+                            self.streamed_args_for_tool[
+                                self.current_tool_id
+                            ] += argument_diff
+
+            self.prev_tool_call_arr = tool_call_arr
+            return res
+
+        except Exception as e:
+            print(e)
+            # Skipping chunk as a result of tool streaming extraction error
+            return StreamingParseResult()
+
+
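The streaming logic above hinges on two helpers: `partial_json_parser` turns an incomplete JSON fragment into a best-effort object, and `_find_common_prefix` bounds the argument delta that is safe to emit between two partial parses, since a later chunk may still change the tail. A standalone sketch of that mechanism (ours, not part of the diff):

```python
import json

import partial_json_parser
from partial_json_parser.core.options import Allow

def common_prefix(s1: str, s2: str) -> str:
    # Same job as the module's _find_common_prefix helper.
    n = min(len(s1), len(s2))
    i = 0
    while i < n and s1[i] == s2[i]:
        i += 1
    return s1[:i]

# Two snapshots of one streamed tool call, a few tokens apart.
earlier = '{"name": "get_weather", "arguments": {"location": "Bos'
later = '{"name": "get_weather", "arguments": {"location": "Boston, MA"}}'

# Allow.ALL mirrors the flags used once the function name has been sent:
# partial strings may be completed on a best-effort basis.
prev = json.dumps(partial_json_parser.loads(earlier, Allow.ALL)["arguments"])
cur = json.dumps(json.loads(later)["arguments"])

sent = 0  # characters of the arguments JSON already streamed to the client
safe = common_prefix(prev, cur)
print(safe[sent:])  # '{"location": "Bos' -- the part both snapshots agree on
```

Only the agreed-on prefix is streamed, so a client never receives argument text that a later, more complete parse would have to retract.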
+class Qwen25Detector(BaseFormatDetector):
+    """
+    Detector for Qwen 2.5 models.
+    Assumes function call format:
+        <tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
+    """
+
+    def __init__(self):
+        """
+        Initializes the detector with necessary state variables.
+        """
+        super().__init__()
+        self.bot_token = "<tool_call>"
+        self.eot_token = "</tool_call>"
+
+    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+        """
+        One-time parsing: Detects and parses tool calls in the provided text.
+
+        :param text: The complete text to parse.
+        :param tools: List of available tools.
+        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+        """
+        if "<tool_call>" not in text:
+            return []
+        pattern = r"<tool_call>(.*?)</tool_call>"
+        match_result_list = re.findall(pattern, text, re.DOTALL)
+        calls = []
+        for match_result in match_result_list:
+            match_result = json.loads(match_result)
+            calls.extend(self.parse_base_json(match_result, tools))
+        return calls
+
+
+class MistralDetector(BaseFormatDetector):
+    """
+    Detector for Mistral models.
+    Assumes function call format:
+        <|action_start|><|plugin|>{"name":"xxx", "arguments":{...}}<|action_end|>
+    """
+
+    def __init__(self):
+        """
+        Initializes the detector with necessary state variables.
+        """
+        super().__init__()
+        self.bot_token = "[TOOL_CALLS] ["
+        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
+
+    def _clean_text(self, text: str) -> str:
+        """
+        clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
+        for example,
+            text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
+            return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
+        The key pattern is [TOOL_CALLS] [...]
+        """
+        find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
+        if len(find_results) > 0:
+            return find_results[0]
+        else:
+            return ""
+
+    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+        """
+        One-time parsing: Detects and parses tool calls in the provided text.
+
+        :param text: The complete text to parse.
+        :param tools: List of available tools.
+        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+        """
+        text = self._clean_text(text)
+        tool_content = text.replace("[TOOL_CALLS]", "").strip()
+        raw_tool_calls = self.tool_call_regex.findall(tool_content)
+        calls = []
+        if len(raw_tool_calls) > 0:
+            raw_tool_call = raw_tool_calls[0]
+            function_call_arr = json.loads(raw_tool_call)
+            for match_result in function_call_arr:
+                calls.extend(self.parse_base_json(match_result, tools))
+        return calls
+
+
+class Llama32Detector(BaseFormatDetector):
+    """
+    Detector for Llama 3.2 models.
+    Assumes function call format:
+        <|python_tag|>{"name":"xxx", "arguments":{...}}
+    Does not require a closing tag "</python_tag|>",
+    relies on json.loads(...) success to determine if JSON is complete.
+    """
+
+    def __init__(self):
+        """
+        Initializes the detector with necessary state variables.
+        """
+        super().__init__()
+        self.bot_token = "<|python_tag|>"
+
+    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+        """
+        One-time parsing: Detects and parses tool calls in the provided text.
+
+        :param text: The complete text to parse.
+        :param tools: List of available tools.
+        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+        """
+
+        if "<|python_tag|>" not in text:
+            return []
+        _, action = text.split("<|python_tag|>")
+        action = json.loads(action)
+        return self.parse_base_json(action, tools)
+
+
+class MultiFormatParser:
+    def __init__(self, detectors: List[BaseFormatDetector]):
+        """
+        :param detectors: A series of available Detector instances passed in
+        """
+        self.detectors = detectors
+
+    def parse_once(self, text: str, tools: List[Function]):
+        """
+        One-time parsing: Loop through detectors until there are no new matches or text is exhausted
+        Return: (final_text, all_calls)
+        - final_text: The remaining text after parsing that was not consumed by any Detector (can be treated as normal text)
+        - all_calls: All calls parsed by the Detectors
+        """
+        final_calls = []
+        final_normal_text = text
+        for detector in self.detectors:
+            tool_call_list = detector.detect_and_parse(text, tools)
+            if len(tool_call_list) > 0:  # parsed successfully
+                final_calls = tool_call_list
+                break
+
+        # leftover_text is the normal text not consumed by any Detector
+        return final_normal_text, final_calls
+
+    def parse_streaming_increment(self, new_text: str, tools: List[Function]):
+        """
+        Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
+        and merge their produced normal_text/calls to return.
+        (The logic here can be "priority-based" or "parallel parsing" based on your needs)
+        """
+        final_normal_text = ""
+        final_calls = []
+
+        for detector in self.detectors:
+            sp_result = detector.parse_streaming_increment(new_text, tools)
+            # Merge normal_text and calls
+            # If one sp_result contains result call, this should be a successful parse
+            # If one sp_result only contains normal_text, this can either be a successful
+            # parse or it is not using the desired parsing tool.
+            if sp_result.normal_text:
+                final_normal_text = sp_result.normal_text
+            if sp_result.calls:
+                final_calls.extend(sp_result.calls)
+                final_normal_text = sp_result.normal_text
+                break
+
+        return final_normal_text, final_calls
+
+
+class FunctionCallParser:
+    """
+    In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment
+    and returns the resulting normal_text and calls to the upper layer (or SSE).
+    """
+
+    ToolCallParserEnum: Dict[str, BaseFormatDetector] = {
+        "llama3": Llama32Detector,
+        "qwen25": Qwen25Detector,
+        "mistral": MistralDetector,
+    }
+
+    def __init__(self, tools: List[Function], tool_call_parser: str = None):
+        detectors = []
+        if tool_call_parser:
+            detector_class = self.ToolCallParserEnum.get(tool_call_parser)
+            if detector_class:
+                detectors.append(detector_class())
+            else:
+                raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
+        else:
+            raise ValueError("Tool Call Parser Not Given!")
+
+        self.multi_format_parser = MultiFormatParser(detectors)
+        self.tools = tools
+
+    def parse_non_stream(self, full_text: str):
+        """
+        Non-streaming call: one-time parsing
+        """
+        full_normal_text, calls = self.multi_format_parser.parse_once(
+            full_text, self.tools
+        )
+        return full_normal_text, calls
+
+    def parse_stream_chunk(self, chunk_text: str):
+        """
+        Streaming call: incremental parsing
+        """
+        normal_text, calls = self.multi_format_parser.parse_streaming_increment(
+            chunk_text, self.tools
+        )
+        return normal_text, calls
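For orientation, here is a usage sketch of the new parser (ours, not part of the package). One caveat visible in the code above: `parse_base_json` resolves the tool index via `tool.function.name`, so despite the `List[Function]` annotation the parser expects OpenAI-style Tool objects that carry a `.function` attribute; the `SimpleNamespace` below stands in for that shape.

```python
from types import SimpleNamespace

from sglang.srt.function_call_parser import Function, FunctionCallParser

# Wrap each Function the way parse_base_json expects: tool.function.name.
weather_tool = SimpleNamespace(
    function=Function(name="get_current_weather", parameters={"type": "object"})
)
parser = FunctionCallParser(tools=[weather_tool], tool_call_parser="qwen25")

# Non-streaming: parse a finished completion in one pass.
text = '<tool_call>{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}</tool_call>'
normal_text, calls = parser.parse_non_stream(text)
print(calls[0].name, calls[0].parameters)
# get_current_weather {"location": "Boston, MA"}

# Streaming: feed chunks as they arrive. The function name is emitted first;
# argument deltas follow on later increments once enough JSON has accumulated.
chunks = (
    '<tool_call>{"name": "get_current_weather", ',
    '"arguments": {"location": "Boston, MA"}}</tool_call>',
    "",  # a final empty increment flushes the completed argument JSON
)
for chunk in chunks:
    normal_text, calls = parser.parse_stream_chunk(chunk)
    for call in calls:
        print(call.tool_index, repr(call.name), repr(call.parameters))
```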
sglang/srt/layers/activation.py
@@ -20,18 +20,18 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_cuda_available
 
-if is_flashinfer_available():
-    from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+if is_cuda_available():
+    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
-from vllm.distributed import (
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.custom_op import CustomOp
-
 from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs
@@ -149,8 +149,8 @@ def get_act_fn(
     return act_fn
 
 
-if not is_flashinfer_available():
+if not is_cuda_available():
     logger.info(
-        "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
    )
    from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
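For reference, `silu_and_mul` fuses the SwiGLU gating used by these activation layers: the last dimension of the input is split in half, SiLU is applied to the first half, and the result is multiplied elementwise by the second half. A plain-PyTorch restatement of those semantics (an illustration, not the fused kernel from either library):

```python
import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Reference semantics of the fused kernel: for input [..., 2d],
    # return silu(x[..., :d]) * x[..., d:], giving shape [..., d].
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 8)
print(silu_and_mul_ref(x).shape)  # torch.Size([4, 4])
```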
sglang/srt/layers/attention/flashinfer_backend.py
@@ -18,6 +18,7 @@ import triton.language as tl
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
 
@@ -62,9 +63,9 @@ class FlashInferAttnBackend(AttentionBackend):
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
             num_attention_heads=model_runner.model_config.num_attention_heads
-            // model_runner.tp_size,
+            // get_attention_tp_size(),
             num_kv_heads=model_runner.model_config.get_num_kv_heads(
-                model_runner.tp_size
+                get_attention_tp_size()
             ),
         )
         self.max_context_len = model_runner.model_config.context_len
@@ -147,7 +148,7 @@ class FlashInferAttnBackend(AttentionBackend):
         self.prefill_cuda_graph_metadata = {}
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
@@ -238,7 +239,7 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
                 decode_wrappers.append(
@@ -307,7 +308,7 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
@@ -453,10 +454,10 @@ class FlashInferIndicesUpdaterDecode:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         # Parse Constants
         self.num_qo_heads = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
         self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            model_runner.tp_size
+            get_attention_tp_size()
        )
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
@@ -625,10 +626,10 @@ class FlashInferIndicesUpdaterPrefill:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         # Parse Constants
         self.num_qo_heads = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
         self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            model_runner.tp_size
+            get_attention_tp_size()
        )
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
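`get_attention_tp_size()` comes from `sglang/srt/layers/dp_attention.py`, which is new in this release but not expanded in this diff. As a rough illustration of what such a helper centralizes (names and structure below are assumptions, not the module's actual contents): with data-parallel attention, the attention layers shard over `tp_size // dp_size` ranks rather than the full tensor-parallel world, so every head-count division goes through one accessor instead of branching at each call site.

```python
from typing import Optional

# Illustrative sketch only; the real dp_attention module is not shown in this diff.
_ATTN_TP_SIZE: Optional[int] = None

def initialize_dp_attention(tp_size: int, dp_size: int) -> None:
    # With DP attention each replica runs attention over tp_size // dp_size
    # ranks; with dp_size == 1 this reduces to ordinary tensor parallelism.
    global _ATTN_TP_SIZE
    _ATTN_TP_SIZE = tp_size // dp_size

def get_attention_tp_size() -> int:
    assert _ATTN_TP_SIZE is not None, "initialize_dp_attention() was not called"
    return _ATTN_TP_SIZE
```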
sglang/srt/layers/attention/triton_backend.py
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional
 import torch
 
 from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
@@ -28,12 +29,9 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
 
-        if model_runner.server_args.enable_dp_attention:
-            self.num_head = model_runner.model_config.num_attention_heads
-        else:
-            self.num_head = (
-                model_runner.model_config.num_attention_heads // model_runner.tp_size
-            )
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
 
         self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
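The removed `enable_dp_attention` branch falls out of the same accessor: if DP attention runs with an attention TP size of 1 (as the deleted branch, which kept every head on each rank, implies), one division covers both cases. A quick worked check under that assumption:

```python
def num_local_heads(num_attention_heads: int, attention_tp_size: int) -> int:
    # The unified formula now used by both the Triton and FlashInfer backends.
    return num_attention_heads // attention_tp_size

# Plain tensor parallelism: heads shard across all 8 attention-TP ranks.
assert num_local_heads(32, attention_tp_size=8) == 4
# DP attention with attention_tp_size == 1: every replica keeps all 32 heads,
# matching the deleted `if enable_dp_attention` branch.
assert num_local_heads(32, attention_tp_size=1) == 32
```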