sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +2 -2
- sglang/srt/configs/model_config.py +36 -2
- sglang/srt/conversation.py +56 -3
- sglang/srt/disaggregation/ascend/__init__.py +6 -0
- sglang/srt/disaggregation/ascend/conn.py +44 -0
- sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
- sglang/srt/disaggregation/mooncake/conn.py +50 -18
- sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
- sglang/srt/disaggregation/utils.py +25 -3
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +1 -0
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +11 -0
- sglang/srt/entrypoints/openai/serving_chat.py +7 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/kimik2_detector.py +220 -0
- sglang/srt/hf_transformers_utils.py +18 -0
- sglang/srt/jinja_template_utils.py +8 -0
- sglang/srt/layers/communicator.py +20 -5
- sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
- sglang/srt/layers/layernorm.py +2 -2
- sglang/srt/layers/linear.py +12 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
- sglang/srt/layers/moe/ep_moe/layer.py +141 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/topk.py +8 -2
- sglang/srt/layers/parameter.py +19 -3
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +28 -7
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -2
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +738 -14
- sglang/srt/layers/vocab_parallel_embedding.py +9 -3
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/io_struct.py +35 -3
- sglang/srt/managers/mm_utils.py +59 -96
- sglang/srt/managers/schedule_batch.py +17 -6
- sglang/srt/managers/scheduler.py +38 -6
- sglang/srt/managers/tokenizer_manager.py +16 -0
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +176 -101
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -1
- sglang/srt/model_loader/loader.py +23 -12
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +78 -19
- sglang/srt/models/deepseek_vl2.py +1 -1
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +6 -3
- sglang/srt/models/internvl.py +8 -2
- sglang/srt/models/kimi_vl.py +8 -2
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llava.py +3 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpmo.py +1 -2
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral_quant.py +4 -0
- sglang/srt/models/mllama4.py +372 -82
- sglang/srt/models/phi4mm.py +8 -2
- sglang/srt/models/phimoe.py +553 -0
- sglang/srt/models/qwen2.py +2 -0
- sglang/srt/models/qwen2_5_vl.py +10 -7
- sglang/srt/models/qwen2_vl.py +12 -1
- sglang/srt/models/vila.py +8 -2
- sglang/srt/multimodal/mm_utils.py +2 -2
- sglang/srt/multimodal/processors/base_processor.py +197 -137
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
- sglang/srt/multimodal/processors/gemma3.py +4 -2
- sglang/srt/multimodal/processors/gemma3n.py +1 -1
- sglang/srt/multimodal/processors/internvl.py +1 -1
- sglang/srt/multimodal/processors/janus_pro.py +1 -1
- sglang/srt/multimodal/processors/kimi_vl.py +1 -1
- sglang/srt/multimodal/processors/minicpm.py +4 -3
- sglang/srt/multimodal/processors/mllama4.py +63 -61
- sglang/srt/multimodal/processors/phi4mm.py +1 -1
- sglang/srt/multimodal/processors/pixtral.py +1 -1
- sglang/srt/multimodal/processors/qwen_vl.py +203 -80
- sglang/srt/multimodal/processors/vila.py +1 -1
- sglang/srt/server_args.py +26 -4
- sglang/srt/two_batch_overlap.py +3 -0
- sglang/srt/utils.py +191 -48
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/openai/serving_chat.py
CHANGED
@@ -82,6 +82,7 @@ class OpenAIServingChat(OpenAIServingBase):
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
             image_data=processed_messages.image_data,
+            video_data=processed_messages.video_data,
             audio_data=processed_messages.audio_data,
             sampling_params=sampling_params,
             return_logprob=request.logprobs,
@@ -143,6 +144,7 @@ class OpenAIServingChat(OpenAIServingBase):
         prompt_ids = []
         openai_compatible_messages = []
         image_data = []
+        video_data = []
         audio_data = []
         modalities = []

@@ -158,6 +160,7 @@ class OpenAIServingChat(OpenAIServingBase):
                 msg_dict,
                 template_content_format,
                 image_data,
+                video_data,
                 audio_data,
                 modalities,
             )
@@ -214,11 +217,13 @@ class OpenAIServingChat(OpenAIServingBase):
         stop = request.stop
         image_data = image_data if image_data else None
         audio_data = audio_data if audio_data else None
+        video_data = video_data if video_data else None
         modalities = modalities if modalities else []
         return MessageProcessingResult(
             prompt=prompt,
             prompt_ids=prompt_ids,
             image_data=image_data,
+            video_data=video_data,
             audio_data=audio_data,
             modalities=modalities,
             stop=stop,
@@ -260,6 +265,7 @@ class OpenAIServingChat(OpenAIServingBase):
         prompt = conv.get_prompt()

         image_data = conv.image_data if conv.image_data else None
+        video_data = conv.video_data if conv.video_data else None
         audio_data = conv.audio_data if conv.audio_data else None
         modalities = conv.modalities if conv.modalities else []
         stop = copy.copy(conv.stop_str or [] if not request.ignore_eos else [])
@@ -277,6 +283,7 @@ class OpenAIServingChat(OpenAIServingBase):
             prompt=prompt,
             prompt_ids=prompt_ids,
             image_data=image_data,
+            video_data=video_data,
             audio_data=audio_data,
             modalities=modalities,
             stop=stop,
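The additions above thread a new video_data field through chat message processing alongside image_data and audio_data. A minimal request sketch, assuming an sglang server is already running with a video-capable model (the model name, port, and video URL are placeholders, not part of the diff):

import openai

client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="default",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this clip."},
                # "video_url" parts are what populate processed_messages.video_data
                {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
            ],
        }
    ],
)
print(response.choices[0].message.content)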
sglang/srt/function_call/function_call_parser.py
CHANGED
@@ -10,6 +10,7 @@ from sglang.srt.entrypoints.openai.protocol import (
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import ToolCallItem
 from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
+from sglang.srt.function_call.kimik2_detector import KimiK2Detector
 from sglang.srt.function_call.llama32_detector import Llama32Detector
 from sglang.srt.function_call.mistral_detector import MistralDetector
 from sglang.srt.function_call.pythonic_detector import PythonicDetector
@@ -33,6 +34,7 @@ class FunctionCallParser:
         "mistral": MistralDetector,
         "deepseekv3": DeepSeekV3Detector,
         "pythonic": PythonicDetector,
+        "kimi_k2": KimiK2Detector,
     }

     def __init__(self, tools: List[Tool], tool_call_parser: str):
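The new mapping registers the detector under the name kimi_k2, so it can be selected by that string wherever a tool-call parser is chosen (presumably including the server's tool-call-parser option). A minimal sketch, assuming the Tool/Function models from the protocol module; the tool definition is illustrative only:

from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.function_call_parser import FunctionCallParser

tools = [Tool(type="function", function=Function(name="get_weather", parameters={}))]
# "kimi_k2" resolves to KimiK2Detector through the mapping added above
parser = FunctionCallParser(tools=tools, tool_call_parser="kimi_k2")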
sglang/srt/function_call/kimik2_detector.py
ADDED
@@ -0,0 +1,220 @@
+import json
+import logging
+import re
+from typing import List
+
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+    StreamingParseResult,
+    StructureInfo,
+    ToolCallItem,
+    _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+from sglang.srt.function_call.utils import _is_complete_json
+
+logger = logging.getLogger(__name__)
+
+
+class KimiK2Detector(BaseFormatDetector):
+
+    def __init__(self):
+        super().__init__()
+        self._buffer = ""
+        self.current_tool_name_sent: bool = False
+        self.prev_tool_call_arr: list[dict] = []
+        self.current_tool_id: int = -1
+        self.streamed_args_for_tool: list[str] = (
+            []
+        )  # map what has been streamed for each tool so far to a list
+
+        self.bot_token: str = "<|tool_calls_section_begin|>"
+        self.eot_token: str = "<|tool_calls_section_end|>"
+
+        self.tool_call_start_token: str = "<|tool_call_begin|>"
+        self.tool_call_end_token: str = "<|tool_call_end|>"
+
+        self.tool_call_regex = re.compile(
+            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>"
+        )
+
+        self.stream_tool_call_portion_regex = re.compile(
+            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)"
+        )
+
+        self._last_arguments = ""
+
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a KimiK2 format tool call."""
+        return self.bot_token in text
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        """
+        One-time parsing: Detects and parses tool calls in the provided text.
+
+        :param text: The complete text to parse.
+        :param tools: List of available tools.
+        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+        """
+        if self.bot_token not in text:
+            return StreamingParseResult(normal_text=text, calls=[])
+        try:
+            # there are two possible captures - between tags, or between a
+            # tag and end-of-string so the result of
+            # findall is an array of tuples where one is a function call and
+            # the other is None
+            function_call_tuples = self.tool_call_regex.findall(text)
+
+            logger.debug("function_call_tuples: %s", function_call_tuples)
+
+            tool_calls = []
+            for match in function_call_tuples:
+                function_id, function_args = match
+                function_name = function_id.split(".")[1].split(":")[0]
+                function_idx = int(function_id.split(".")[1].split(":")[1])
+
+                logger.info(f"function_name {function_name}")
+
+                tool_calls.append(
+                    ToolCallItem(
+                        tool_index=function_idx,  # Use the call index in the response, not tool position
+                        name=function_name,
+                        parameters=function_args,
+                    )
+                )
+
+            content = text[: text.find(self.bot_token)]
+            return StreamingParseResult(normal_text=content, calls=tool_calls)
+
+        except Exception as e:
+            logger.error(f"Error in detect_and_parse: {e}")
+            # return the normal text if parsing fails
+            return StreamingParseResult(normal_text=text)
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing tool calls for KimiK2 format.
+        """
+        self._buffer += new_text
+        current_text = self._buffer
+
+        # Check if we have a tool call (either the start token or individual tool call)
+        has_tool_call = (
+            self.bot_token in current_text or self.tool_call_start_token in current_text
+        )
+
+        if not has_tool_call:
+            self._buffer = ""
+            for e_token in [self.eot_token, self.tool_call_end_token]:
+                if e_token in new_text:
+                    new_text = new_text.replace(e_token, "")
+            return StreamingParseResult(normal_text=new_text)
+
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = {
+                tool.function.name: i
+                for i, tool in enumerate(tools)
+                if tool.function and tool.function.name
+            }
+
+        calls: list[ToolCallItem] = []
+        try:
+            match = self.stream_tool_call_portion_regex.search(current_text)
+            if match:
+                function_id = match.group("tool_call_id")
+                function_args = match.group("function_arguments")
+
+                function_name = function_id.split(".")[1].split(":")[0]
+
+                # Initialize state if this is the first tool call
+                if self.current_tool_id == -1:
+                    self.current_tool_id = 0
+                    self.prev_tool_call_arr = []
+                    self.streamed_args_for_tool = [""]
+
+                # Ensure we have enough entries in our tracking arrays
+                while len(self.prev_tool_call_arr) <= self.current_tool_id:
+                    self.prev_tool_call_arr.append({})
+                while len(self.streamed_args_for_tool) <= self.current_tool_id:
+                    self.streamed_args_for_tool.append("")
+
+                if not self.current_tool_name_sent:
+                    calls.append(
+                        ToolCallItem(
+                            tool_index=self.current_tool_id,
+                            name=function_name,
+                            parameters="",
+                        )
+                    )
+                    self.current_tool_name_sent = True
+                    # Store the tool call info for adapter.py
+                    self.prev_tool_call_arr[self.current_tool_id] = {
+                        "name": function_name,
+                        "arguments": {},
+                    }
+                else:
+                    argument_diff = (
+                        function_args[len(self._last_arguments) :]
+                        if function_args.startswith(self._last_arguments)
+                        else function_args
+                    )
+
+                    parsed_args_diff = argument_diff.split("<|tool_call_end|>", 1)[0]
+
+                    if parsed_args_diff:
+
+                        calls.append(
+                            ToolCallItem(
+                                tool_index=self.current_tool_id,
+                                name=None,
+                                parameters=parsed_args_diff,
+                            )
+                        )
+                        self._last_arguments += argument_diff
+                        self.streamed_args_for_tool[
+                            self.current_tool_id
+                        ] += parsed_args_diff
+
+                    parsed_args = function_args.split("<|tool_call_end|>", 1)[0]
+                    if _is_complete_json(parsed_args):
+                        try:
+                            parsed_args = json.loads(parsed_args)
+                            self.prev_tool_call_arr[self.current_tool_id][
+                                "arguments"
+                            ] = parsed_args
+                        except json.JSONDecodeError:
+                            pass
+
+                        # Find the end of the current tool call and remove only that part from buffer
+                        tool_call_end_pattern = (
+                            r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>"
+                        )
+                        match = re.search(
+                            tool_call_end_pattern, current_text, re.DOTALL
+                        )
+                        if match:
+                            # Remove the completed tool call from buffer, keep any remaining content
+                            self._buffer = current_text[match.end() :]
+                        else:
+                            self._buffer = ""
+
+                        result = StreamingParseResult(normal_text="", calls=calls)
+                        self.current_tool_id += 1
+                        self._last_arguments = ""
+                        self.current_tool_name_sent = False
+                        return result
+
+            return StreamingParseResult(normal_text="", calls=calls)
+
+        except Exception as e:
+            logger.error(f"Error in parse_streaming_increment: {e}")
+            return StreamingParseResult(normal_text=current_text)
+
+    def structure_info(self) -> _GetInfoFunc:
+        raise NotImplementedError()
+
+    def build_ebnf(self, tools: List[Tool]):
+        raise NotImplementedError()
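For orientation, a sketch of the token format the detector's regexes accept and how detect_and_parse consumes it. The "functions.<name>:<index>" id shape is inferred from the split(".")/split(":") parsing above and may differ from the exact Kimi-K2 chat template; the tool definition is illustrative:

from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.kimik2_detector import KimiK2Detector

text = (
    "Sure, let me check."
    "<|tool_calls_section_begin|>"
    "<|tool_call_begin|>functions.get_weather:0"
    '<|tool_call_argument_begin|>{"city": "Beijing"}<|tool_call_end|>'
    "<|tool_calls_section_end|>"
)
tools = [Tool(type="function", function=Function(name="get_weather", parameters={}))]

result = KimiK2Detector().detect_and_parse(text, tools)
# result.normal_text == "Sure, let me check."
# result.calls[0].name == "get_weather"
# result.calls[0].parameters == '{"city": "Beijing"}'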
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -14,6 +14,7 @@
 """Utilities for Huggingface Transformers."""

 import contextlib
+import logging
 import os
 import warnings
 from pathlib import Path
@@ -25,6 +26,7 @@ from transformers import (
     AutoConfig,
     AutoProcessor,
     AutoTokenizer,
+    GenerationConfig,
     PretrainedConfig,
     PreTrainedTokenizer,
     PreTrainedTokenizerBase,
@@ -153,6 +155,22 @@ def get_config(
     return config


+@lru_cache_frozenset(maxsize=32)
+def get_generation_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+    **kwargs,
+):
+    try:
+        return GenerationConfig.from_pretrained(
+            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+        )
+    except OSError as e:
+        logging.info("model doesn't have generation_config.json")
+        return None
+
+
 # Models don't use the same configuration key for determining the maximum
 # context length. Store them here so we can sanely check them.
 # NOTE: The ordering here is important. Some models have two of these and we
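A brief usage sketch for the new helper (the model id is a placeholder); it returns None when the checkpoint ships no generation_config.json:

from sglang.srt.hf_transformers_utils import get_generation_config

gen_cfg = get_generation_config("Qwen/Qwen2-0.5B-Instruct", trust_remote_code=False)
if gen_cfg is not None:
    # default sampling values bundled with the checkpoint
    print(gen_cfg.temperature, gen_cfg.top_p)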
sglang/srt/jinja_template_utils.py
CHANGED
@@ -110,6 +110,7 @@ def process_content_for_template_format(
     msg_dict: dict,
     content_format: str,
     image_data: list,
+    video_data: list,
     audio_data: list,
     modalities: list,
 ) -> dict:
@@ -120,6 +121,7 @@ def process_content_for_template_format(
         msg_dict: Message dictionary with content
         content_format: 'string' or 'openai' (detected via AST analysis)
         image_data: List to append extracted image URLs
+        video_data: List to append extracted video URLs
         audio_data: List to append extracted audio URLs
         modalities: List to append modalities

@@ -143,6 +145,12 @@ def process_content_for_template_format(
                     modalities.append(chunk.get("modalities"))
                 # Normalize to simple 'image' type for template compatibility
                 processed_content_parts.append({"type": "image"})
+            elif chunk_type == "video_url":
+                video_data.append(chunk["video_url"]["url"])
+                if chunk.get("modalities"):
+                    modalities.append(chunk.get("modalities"))
+                # Normalize to simple 'video' type for template compatibility
+                processed_content_parts.append({"type": "video"})
             elif chunk_type == "audio_url":
                 audio_data.append(chunk["audio_url"]["url"])
                 # Normalize to simple 'audio' type
sglang/srt/layers/communicator.py
CHANGED
@@ -187,11 +187,24 @@ class LayerCommunicator:
         if hidden_states.shape[0] == 0:
             residual = hidden_states
         else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
+            if (
+                residual is not None
+                and hasattr(hidden_states, "_sglang_needs_allreduce_fusion")
+                and hidden_states._sglang_needs_allreduce_fusion
+            ):
+                hidden_states, residual = (
+                    self.input_layernorm.forward_with_allreduce_fusion(
+                        hidden_states, residual
+                    )
+                )
             else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+                if residual is None:
+                    residual = hidden_states
+                    hidden_states = self.input_layernorm(hidden_states)
+                else:
+                    hidden_states, residual = self.input_layernorm(
+                        hidden_states, residual
+                    )

         hidden_states = self._communicate_simple_fn(
             hidden_states=hidden_states,
@@ -402,12 +415,14 @@ class CommunicateWithAllReduceAndLayerNormFn:
         if hidden_states.shape[0] != 0:
             hidden_states = layernorm(hidden_states)
         else:
+            # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
+            # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
             if (
                 _is_sm100_supported
                 and _is_flashinfer_available
                 and hasattr(layernorm, "forward_with_allreduce_fusion")
                 and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-                and hidden_states.shape[0] <=
+                and hidden_states.shape[0] <= 128
             ):
                 hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                     hidden_states, residual
sglang/srt/layers/flashinfer_comm_fusion.py
CHANGED
@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()


 def ensure_workspace_initialized(
-    max_token_num: int =
+    max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
 ):
     """Ensure workspace is initialized"""
     if not is_flashinfer_available() or _flashinfer_comm is None:
@@ -119,12 +119,12 @@ def ensure_workspace_initialized(
     return _workspace_manager.initialized


-def
+def flashinfer_allreduce_residual_rmsnorm(
     input_tensor: torch.Tensor,
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    max_token_num: int =
+    max_token_num: int = 128,
     use_oneshot: bool = True,
     trigger_completion_at_end: bool = False,
     fp32_acc: bool = False,
sglang/srt/layers/layernorm.py
CHANGED
@@ -174,11 +174,11 @@ class RMSNorm(CustomOp):
         if residual is not None:
             from sglang.srt.distributed import get_tensor_model_parallel_world_size
             from sglang.srt.layers.flashinfer_comm_fusion import (
-
+                flashinfer_allreduce_residual_rmsnorm,
             )

             if get_tensor_model_parallel_world_size() > 1:
-                fused_result =
+                fused_result = flashinfer_allreduce_residual_rmsnorm(
                     input_tensor=x,
                     residual=residual,
                     weight=self.weight,
sglang/srt/layers/linear.py
CHANGED
@@ -34,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.utils import (
     cpu_has_amx_support,
     is_cpu,
+    is_npu,
     set_weight_attrs,
     use_intel_amx_backend,
 )
@@ -60,6 +61,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [

 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_npu = is_npu()


 def adjust_marlin_shard(param, shard_size, shard_offset):
@@ -297,6 +299,14 @@ class ReplicatedLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)

+        # The per-tensor quant-scale must be 1 dimension
+        if _is_npu:
+            if param.size() != loaded_weight.size() and param.size(0) == 1:
+                if torch.allclose(loaded_weight, loaded_weight[0]):
+                    loaded_weight = loaded_weight[:1]
+                else:
+                    raise ValueError(f"{loaded_weight} are not all equal")
+
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)

@@ -1357,7 +1367,7 @@ class RowParallelLinear(LinearBase):
         # It does not support additional parameters.
         param.load_row_parallel_weight(loaded_weight)

-    def forward(self, input_):
+    def forward(self, input_, can_fuse_mlp_allreduce=False):
         if self.input_is_parallel:
             input_parallel = input_
         else:
@@ -1372,7 +1382,7 @@ class RowParallelLinear(LinearBase):
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
         output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
-        if self.reduce_results and self.tp_size > 1:
+        if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output = output_parallel
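The NPU branch in ReplicatedLinear's weight loader above collapses a replicated per-tensor quantization scale to a single element before the size assertion. A minimal, self-contained sketch of that check using a hypothetical helper (plain torch, outside the loader; not part of the diff):

import torch

def collapse_per_tensor_scale(param: torch.Tensor, loaded_weight: torch.Tensor) -> torch.Tensor:
    # Mirrors the added branch: a per-tensor scale stored as a replicated
    # vector is reduced to one element so it matches the 1-element parameter.
    if param.size() != loaded_weight.size() and param.size(0) == 1:
        if torch.allclose(loaded_weight, loaded_weight[0]):
            return loaded_weight[:1]
        raise ValueError("per-tensor scale entries are not all equal")
    return loaded_weight

print(collapse_per_tensor_scale(torch.zeros(1), torch.full((128,), 0.02)))  # tensor([0.0200])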