sglang 0.4.9.post1__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +24 -1
- sglang/srt/conversation.py +21 -2
- sglang/srt/disaggregation/ascend/__init__.py +6 -0
- sglang/srt/disaggregation/ascend/conn.py +44 -0
- sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
- sglang/srt/disaggregation/mooncake/conn.py +15 -14
- sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
- sglang/srt/disaggregation/utils.py +25 -3
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +1 -0
- sglang/srt/entrypoints/openai/protocol.py +11 -0
- sglang/srt/entrypoints/openai/serving_chat.py +7 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/kimik2_detector.py +220 -0
- sglang/srt/hf_transformers_utils.py +18 -0
- sglang/srt/jinja_template_utils.py +8 -0
- sglang/srt/layers/communicator.py +17 -4
- sglang/srt/layers/linear.py +12 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -2
- sglang/srt/layers/moe/topk.py +8 -2
- sglang/srt/layers/parameter.py +19 -3
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +738 -14
- sglang/srt/managers/io_struct.py +27 -2
- sglang/srt/managers/mm_utils.py +55 -94
- sglang/srt/managers/schedule_batch.py +16 -5
- sglang/srt/managers/scheduler.py +21 -1
- sglang/srt/managers/tokenizer_manager.py +16 -0
- sglang/srt/mem_cache/memory_pool.py +65 -40
- sglang/srt/model_executor/forward_batch_info.py +13 -1
- sglang/srt/model_loader/loader.py +23 -12
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +62 -17
- sglang/srt/models/deepseek_vl2.py +1 -1
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +6 -3
- sglang/srt/models/internvl.py +8 -2
- sglang/srt/models/kimi_vl.py +8 -2
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llava.py +3 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpmo.py +1 -2
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral_quant.py +4 -0
- sglang/srt/models/mllama4.py +13 -4
- sglang/srt/models/phi4mm.py +8 -2
- sglang/srt/models/phimoe.py +553 -0
- sglang/srt/models/qwen2.py +2 -0
- sglang/srt/models/qwen2_5_vl.py +10 -7
- sglang/srt/models/qwen2_vl.py +12 -1
- sglang/srt/models/vila.py +8 -2
- sglang/srt/multimodal/processors/base_processor.py +197 -137
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
- sglang/srt/multimodal/processors/gemma3.py +4 -2
- sglang/srt/multimodal/processors/gemma3n.py +1 -1
- sglang/srt/multimodal/processors/internvl.py +1 -1
- sglang/srt/multimodal/processors/janus_pro.py +1 -1
- sglang/srt/multimodal/processors/kimi_vl.py +1 -1
- sglang/srt/multimodal/processors/minicpm.py +4 -3
- sglang/srt/multimodal/processors/mllama4.py +1 -1
- sglang/srt/multimodal/processors/phi4mm.py +1 -1
- sglang/srt/multimodal/processors/pixtral.py +1 -1
- sglang/srt/multimodal/processors/qwen_vl.py +203 -80
- sglang/srt/multimodal/processors/vila.py +1 -1
- sglang/srt/server_args.py +11 -4
- sglang/srt/utils.py +154 -31
- sglang/version.py +1 -1
- {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +4 -3
- {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +75 -70
- {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/function_call/kimik2_detector.py
ADDED
@@ -0,0 +1,220 @@
+import json
+import logging
+import re
+from typing import List
+
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+    StreamingParseResult,
+    StructureInfo,
+    ToolCallItem,
+    _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+from sglang.srt.function_call.utils import _is_complete_json
+
+logger = logging.getLogger(__name__)
+
+
+class KimiK2Detector(BaseFormatDetector):
+
+    def __init__(self):
+        super().__init__()
+        self._buffer = ""
+        self.current_tool_name_sent: bool = False
+        self.prev_tool_call_arr: list[dict] = []
+        self.current_tool_id: int = -1
+        self.streamed_args_for_tool: list[str] = (
+            []
+        )  # map what has been streamed for each tool so far to a list
+
+        self.bot_token: str = "<|tool_calls_section_begin|>"
+        self.eot_token: str = "<|tool_calls_section_end|>"
+
+        self.tool_call_start_token: str = "<|tool_call_begin|>"
+        self.tool_call_end_token: str = "<|tool_call_end|>"
+
+        self.tool_call_regex = re.compile(
+            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>"
+        )
+
+        self.stream_tool_call_portion_regex = re.compile(
+            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)"
+        )
+
+        self._last_arguments = ""
+
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a KimiK2 format tool call."""
+        return self.bot_token in text
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        """
+        One-time parsing: Detects and parses tool calls in the provided text.
+
+        :param text: The complete text to parse.
+        :param tools: List of available tools.
+        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+        """
+        if self.bot_token not in text:
+            return StreamingParseResult(normal_text=text, calls=[])
+        try:
+            # there are two possible captures - between tags, or between a
+            # tag and end-of-string so the result of
+            # findall is an array of tuples where one is a function call and
+            # the other is None
+            function_call_tuples = self.tool_call_regex.findall(text)
+
+            logger.debug("function_call_tuples: %s", function_call_tuples)
+
+            tool_calls = []
+            for match in function_call_tuples:
+                function_id, function_args = match
+                function_name = function_id.split(".")[1].split(":")[0]
+                function_idx = int(function_id.split(".")[1].split(":")[1])
+
+                logger.info(f"function_name {function_name}")
+
+                tool_calls.append(
+                    ToolCallItem(
+                        tool_index=function_idx,  # Use the call index in the response, not tool position
+                        name=function_name,
+                        parameters=function_args,
+                    )
+                )
+
+            content = text[: text.find(self.bot_token)]
+            return StreamingParseResult(normal_text=content, calls=tool_calls)
+
+        except Exception as e:
+            logger.error(f"Error in detect_and_parse: {e}")
+            # return the normal text if parsing fails
+            return StreamingParseResult(normal_text=text)
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing tool calls for KimiK2 format.
+        """
+        self._buffer += new_text
+        current_text = self._buffer
+
+        # Check if we have a tool call (either the start token or individual tool call)
+        has_tool_call = (
+            self.bot_token in current_text or self.tool_call_start_token in current_text
+        )
+
+        if not has_tool_call:
+            self._buffer = ""
+            for e_token in [self.eot_token, self.tool_call_end_token]:
+                if e_token in new_text:
+                    new_text = new_text.replace(e_token, "")
+            return StreamingParseResult(normal_text=new_text)
+
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = {
+                tool.function.name: i
+                for i, tool in enumerate(tools)
+                if tool.function and tool.function.name
+            }
+
+        calls: list[ToolCallItem] = []
+        try:
+            match = self.stream_tool_call_portion_regex.search(current_text)
+            if match:
+                function_id = match.group("tool_call_id")
+                function_args = match.group("function_arguments")
+
+                function_name = function_id.split(".")[1].split(":")[0]
+
+                # Initialize state if this is the first tool call
+                if self.current_tool_id == -1:
+                    self.current_tool_id = 0
+                    self.prev_tool_call_arr = []
+                    self.streamed_args_for_tool = [""]
+
+                # Ensure we have enough entries in our tracking arrays
+                while len(self.prev_tool_call_arr) <= self.current_tool_id:
+                    self.prev_tool_call_arr.append({})
+                while len(self.streamed_args_for_tool) <= self.current_tool_id:
+                    self.streamed_args_for_tool.append("")
+
+                if not self.current_tool_name_sent:
+                    calls.append(
+                        ToolCallItem(
+                            tool_index=self.current_tool_id,
+                            name=function_name,
+                            parameters="",
+                        )
+                    )
+                    self.current_tool_name_sent = True
+                    # Store the tool call info for adapter.py
+                    self.prev_tool_call_arr[self.current_tool_id] = {
+                        "name": function_name,
+                        "arguments": {},
+                    }
+                else:
+                    argument_diff = (
+                        function_args[len(self._last_arguments) :]
+                        if function_args.startswith(self._last_arguments)
+                        else function_args
+                    )
+
+                    parsed_args_diff = argument_diff.split("<|tool_call_end|>", 1)[0]
+
+                    if parsed_args_diff:
+
+                        calls.append(
+                            ToolCallItem(
+                                tool_index=self.current_tool_id,
+                                name=None,
+                                parameters=parsed_args_diff,
+                            )
+                        )
+                        self._last_arguments += argument_diff
+                        self.streamed_args_for_tool[
+                            self.current_tool_id
+                        ] += parsed_args_diff
+
+                    parsed_args = function_args.split("<|tool_call_end|>", 1)[0]
+                    if _is_complete_json(parsed_args):
+                        try:
+                            parsed_args = json.loads(parsed_args)
+                            self.prev_tool_call_arr[self.current_tool_id][
+                                "arguments"
+                            ] = parsed_args
+                        except json.JSONDecodeError:
+                            pass
+
+                        # Find the end of the current tool call and remove only that part from buffer
+                        tool_call_end_pattern = (
+                            r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>"
+                        )
+                        match = re.search(
+                            tool_call_end_pattern, current_text, re.DOTALL
+                        )
+                        if match:
+                            # Remove the completed tool call from buffer, keep any remaining content
+                            self._buffer = current_text[match.end() :]
+                        else:
+                            self._buffer = ""
+
+                        result = StreamingParseResult(normal_text="", calls=calls)
+                        self.current_tool_id += 1
+                        self._last_arguments = ""
+                        self.current_tool_name_sent = False
+                        return result
+
+            return StreamingParseResult(normal_text="", calls=calls)
+
+        except Exception as e:
+            logger.error(f"Error in parse_streaming_increment: {e}")
+            return StreamingParseResult(normal_text=current_text)
+
+    def structure_info(self) -> _GetInfoFunc:
+        raise NotImplementedError()
+
+    def build_ebnf(self, tools: List[Tool]):
+        raise NotImplementedError()
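As a rough usage sketch of the new detector (not taken from the package: the tool name, arguments, and surrounding prose below are invented, only the KimiK2 token layout comes from the regexes above):

# Illustration only: a hand-written KimiK2-style completion run through the new detector.
from sglang.srt.function_call.kimik2_detector import KimiK2Detector

sample = (
    "Let me check the weather."
    "<|tool_calls_section_begin|>"
    "<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>"
    '{"city": "Beijing"}<|tool_call_end|>'
    "<|tool_calls_section_end|>"
)

detector = KimiK2Detector()
result = detector.detect_and_parse(sample, tools=[])  # tools are not consulted by this method

print(result.normal_text)          # "Let me check the weather."
print(result.calls[0].name)        # "get_weather"
print(result.calls[0].parameters)  # '{"city": "Beijing"}'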
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -14,6 +14,7 @@
 """Utilities for Huggingface Transformers."""
 
 import contextlib
+import logging
 import os
 import warnings
 from pathlib import Path
@@ -25,6 +26,7 @@ from transformers import (
     AutoConfig,
     AutoProcessor,
     AutoTokenizer,
+    GenerationConfig,
     PretrainedConfig,
     PreTrainedTokenizer,
     PreTrainedTokenizerBase,
@@ -153,6 +155,22 @@ def get_config(
     return config
 
 
+@lru_cache_frozenset(maxsize=32)
+def get_generation_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+    **kwargs,
+):
+    try:
+        return GenerationConfig.from_pretrained(
+            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+        )
+    except OSError as e:
+        logging.info("model doesn't have generation_config.json")
+        return None
+
+
 # Models don't use the same configuration key for determining the maximum
 # context length. Store them here so we can sanely check them.
 # NOTE: The ordering here is important. Some models have two of these and we
sglang/srt/jinja_template_utils.py
CHANGED
@@ -110,6 +110,7 @@ def process_content_for_template_format(
     msg_dict: dict,
     content_format: str,
     image_data: list,
+    video_data: list,
     audio_data: list,
     modalities: list,
 ) -> dict:
@@ -120,6 +121,7 @@ def process_content_for_template_format(
         msg_dict: Message dictionary with content
         content_format: 'string' or 'openai' (detected via AST analysis)
         image_data: List to append extracted image URLs
+        video_data: List to append extracted video URLs
         audio_data: List to append extracted audio URLs
         modalities: List to append modalities
 
@@ -143,6 +145,12 @@ def process_content_for_template_format(
                 modalities.append(chunk.get("modalities"))
             # Normalize to simple 'image' type for template compatibility
             processed_content_parts.append({"type": "image"})
+        elif chunk_type == "video_url":
+            video_data.append(chunk["video_url"]["url"])
+            if chunk.get("modalities"):
+                modalities.append(chunk.get("modalities"))
+            # Normalize to simple 'video' type for template compatibility
+            processed_content_parts.append({"type": "video"})
         elif chunk_type == "audio_url":
            audio_data.append(chunk["audio_url"]["url"])
            # Normalize to simple 'audio' type
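For reference, a sketch of the kind of OpenAI-style content chunk the new `video_url` branch consumes (the URL and the surrounding message are invented; only the chunk shape mirrors the code above):

# Hypothetical chat message; only the "video_url" chunk structure comes from the branch above.
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this clip."},
        {
            "type": "video_url",
            "video_url": {"url": "https://example.com/clip.mp4"},
            # An optional "modalities" key, when present, is appended to `modalities`.
        },
    ],
}
# process_content_for_template_format() appends the URL to video_data and
# normalizes the chunk to {"type": "video"} before rendering the chat template.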
sglang/srt/layers/communicator.py
CHANGED
@@ -187,11 +187,24 @@ class LayerCommunicator:
         if hidden_states.shape[0] == 0:
             residual = hidden_states
         else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
+            if (
+                residual is not None
+                and hasattr(hidden_states, "_sglang_needs_allreduce_fusion")
+                and hidden_states._sglang_needs_allreduce_fusion
+            ):
+                hidden_states, residual = (
+                    self.input_layernorm.forward_with_allreduce_fusion(
+                        hidden_states, residual
+                    )
+                )
             else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+                if residual is None:
+                    residual = hidden_states
+                    hidden_states = self.input_layernorm(hidden_states)
+                else:
+                    hidden_states, residual = self.input_layernorm(
+                        hidden_states, residual
+                    )
 
         hidden_states = self._communicate_simple_fn(
             hidden_states=hidden_states,
sglang/srt/layers/linear.py
CHANGED
@@ -34,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.utils import (
     cpu_has_amx_support,
     is_cpu,
+    is_npu,
     set_weight_attrs,
     use_intel_amx_backend,
 )
@@ -60,6 +61,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
 
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_npu = is_npu()
 
 
 def adjust_marlin_shard(param, shard_size, shard_offset):
@@ -297,6 +299,14 @@ class ReplicatedLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
+        # The per-tensor quant-scale must be 1 dimension
+        if _is_npu:
+            if param.size() != loaded_weight.size() and param.size(0) == 1:
+                if torch.allclose(loaded_weight, loaded_weight[0]):
+                    loaded_weight = loaded_weight[:1]
+                else:
+                    raise ValueError(f"{loaded_weight} are not all equal")
+
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
 
@@ -1357,7 +1367,7 @@ class RowParallelLinear(LinearBase):
         # It does not support additional parameters.
         param.load_row_parallel_weight(loaded_weight)
 
-    def forward(self, input_):
+    def forward(self, input_, can_fuse_mlp_allreduce=False):
         if self.input_is_parallel:
             input_parallel = input_
         else:
@@ -1372,7 +1382,7 @@ class RowParallelLinear(LinearBase):
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
         output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
-        if self.reduce_results and self.tp_size > 1:
+        if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output = output_parallel
sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED
@@ -6,6 +6,7 @@ import triton
 
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import ceil_div, dispose_tensor, is_cuda
+from sglang.utils import is_in_ci
 
 logger = logging.getLogger(__name__)
 
@@ -1058,7 +1059,7 @@ def ep_gather(
     input_index: torch.Tensor,
     output_tensor: torch.Tensor,
 ):
-    BLOCK_D = 1024  # block size of quantization
+    BLOCK_D = 1024 if not is_in_ci() else 128  # block size of quantization
     num_warps = 2
     num_tokens = output_tensor.shape[0]
     hidden_size = input_tensor.shape[1]
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -12,7 +12,6 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
-from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
@@ -65,6 +64,8 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if not _is_npu:
     from sgl_kernel import silu_and_mul
 
+    from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+
 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
 
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -518,6 +518,7 @@ class FusedMoE(torch.nn.Module):
             self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
         assert self.quant_method is not None
 
+        self.quant_config = quant_config
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.local_num_experts,
@@ -661,7 +662,11 @@ class FusedMoE(torch.nn.Module):
         ):
             raise ValueError("expert_data and loaded_weight must be torch.Tensor")
 
-        if expert_data.dim() != 2 or loaded_weight.dim() != 2:
+        if (
+            self.quant_config is not None
+            and "modelopt" in self.quant_config.get_name()
+            and (expert_data.dim() != 2 or loaded_weight.dim() != 2)
+        ):
             raise ValueError(
                 f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
             )
@@ -850,7 +855,7 @@ class FusedMoE(torch.nn.Module):
             return
 
         # Case weight scales and zero_points
-        if "scale" in weight_name or "zero" in weight_name:
+        if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name:
             # load the weight scales and zp based on the quantization scheme
             # supported weight scales/zp can be found in
             # FusedMoeWeightScaleSupported
sglang/srt/layers/moe/topk.py
CHANGED
@@ -83,13 +83,18 @@ def fused_topk_cpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
-    return torch.ops.sgl_kernel.topk_softmax_cpu(
+    topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
         hidden_states=hidden_states,
         gating_output=gating_output,
         topk=topk,
         renormalize=renormalize,
     )
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_weights, topk_ids
 
 
 def fused_topk(
@@ -303,7 +308,7 @@ def biased_grouped_topk_gpu(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
-    compiled: bool = True,
+    compiled: bool = not _is_npu,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -411,6 +416,7 @@ if _is_cpu and _is_cpu_amx_available:
     biased_grouped_topk = biased_grouped_topk_cpu
     grouped_topk = grouped_topk_cpu
     fused_topk_native = fused_topk_cpu
+    fused_topk = fused_topk_cpu
 else:
     biased_grouped_topk = biased_grouped_topk_gpu
     grouped_topk = grouped_topk_gpu
sglang/srt/layers/parameter.py
CHANGED
@@ -187,10 +187,26 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-        if not use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, shard_id * shard_size, shard_size
+
+        if _is_cpu:
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
+            )
+
+            param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                param_data,
+                loaded_weight,
+                0,  # param_data_start
+                shard_id * shard_size,
+                self.output_dim,
+                shard_size,
+                not use_presharded_weights,
             )
+        else:
+            if not use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, shard_id * shard_size, shard_size
+                )
 
         assert (
             param_data.shape == loaded_weight.shape
sglang/srt/layers/quantization/fp8_kernel.py
CHANGED
@@ -160,8 +160,8 @@ def _per_token_group_quant_fp8_colmajor(
     """
     # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
-    y_ptr += g_id * group_size
-    y_q_ptr += g_id * group_size
+    y_ptr += g_id.to(tl.int64) * group_size
+    y_q_ptr += g_id.to(tl.int64) * group_size
 
     # Convert g_id the flattened block coordinate to 2D so we can index
     # into the output y_scales matrix
sglang/srt/layers/quantization/moe_wna16.py
CHANGED
@@ -116,8 +116,7 @@ class MoeWNA16Config(QuantizationConfig):
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
-        can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
-        if can_convert and user_quant == "moe_wna16":
+        if user_quant == "moe_wna16" and cls.is_moe_wna16_compatible(hf_quant_cfg):
            return cls.get_name()
        return None
 