sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +49 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +35 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -6
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +100 -52
- sglang/srt/disaggregation/prefill.py +5 -4
- sglang/srt/disaggregation/utils.py +13 -12
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +45 -9
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +51 -6
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +56 -23
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +41 -0
- sglang/srt/layers/linear.py +99 -12
- sglang/srt/layers/logits_processor.py +15 -6
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +115 -25
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +36 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +6 -6
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +105 -13
- sglang/srt/layers/vocab_parallel_embedding.py +19 -2
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +60 -15
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +80 -79
- sglang/srt/managers/scheduler.py +153 -63
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +302 -58
- sglang/srt/model_loader/loader.py +86 -10
- sglang/srt/model_loader/weight_utils.py +160 -3
- sglang/srt/models/deepseek_nextn.py +5 -4
- sglang/srt/models/deepseek_v2.py +305 -26
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1010 -0
- sglang/srt/models/gemma3n_mm.py +495 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/multimodal/processors/gemma3n.py +82 -0
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +85 -24
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +204 -28
- sglang/srt/utils.py +369 -138
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
@@ -22,6 +22,7 @@ from sglang.srt.entrypoints.openai.protocol import (
     ErrorResponse,
     FunctionResponse,
     LogProbs,
+    MessageProcessingResult,
     ToolCall,
     TopLogprob,
 )
@@ -62,120 +63,81 @@ class OpenAIServingChat(OpenAIServingBase):
         is_multimodal = self.tokenizer_manager.model_config.is_multimodal

         # Process messages and apply chat template
-        (
-            prompt,
-            prompt_ids,
-            image_data,
-            audio_data,
-            modalities,
-            stop,
-            tool_call_constraint,
-        ) = self._process_messages(request, is_multimodal)
+        processed_messages = self._process_messages(request, is_multimodal)

         # Build sampling parameters
         sampling_params = self._build_sampling_params(
-            request, stop, tool_call_constraint
+            request, processed_messages.stop, processed_messages.tool_call_constraint
         )

         # Handle single vs multiple requests
         if is_multimodal:
-            prompt_kwargs = {"text": prompt}
+            prompt_kwargs = {"text": processed_messages.prompt}
         else:
-            if isinstance(prompt_ids, str):
-                prompt_kwargs = {"text": prompt_ids}
+            if isinstance(processed_messages.prompt_ids, str):
+                prompt_kwargs = {"text": processed_messages.prompt_ids}
             else:
-                prompt_kwargs = {"input_ids": prompt_ids}
+                prompt_kwargs = {"input_ids": processed_messages.prompt_ids}

         adapted_request = GenerateReqInput(
             **prompt_kwargs,
-            image_data=image_data,
-            audio_data=audio_data,
+            image_data=processed_messages.image_data,
+            audio_data=processed_messages.audio_data,
             sampling_params=sampling_params,
             return_logprob=request.logprobs,
             logprob_start_len=-1,
             top_logprobs_num=request.top_logprobs or 0,
             stream=request.stream,
             return_text_in_logprobs=True,
-            modalities=modalities,
+            modalities=processed_messages.modalities,
             lora_path=request.lora_path,
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
             return_hidden_states=request.return_hidden_states,
+            rid=request.rid,
         )

         return adapted_request, request

     def _process_messages(
         self, request: ChatCompletionRequest, is_multimodal: bool
-    ) ->
-        str,
-        Union[str, List[int]],
-        Optional[Any],
-        Optional[Any],
-        List[str],
-        List[str],
-        Optional[Any],
-    ]:
+    ) -> MessageProcessingResult:
         """Process chat messages and apply chat template"""
         tool_call_constraint = None
-        prompt = ""
-        prompt_ids = []

-
-
-
-
-
-
-
-
-
-
-
-
-            tools = [item.function.model_dump() for item in request.tools]
+        # Apply chat template and its stop strings
+        tools = None
+        if request.tools and request.tool_choice != "none":
+            request.skip_special_tokens = False
+            if not isinstance(request.tool_choice, str):
+                tools = [
+                    item.function.model_dump()
+                    for item in request.tools
+                    if item.function.name == request.tool_choice.function.name
+                ]
+            else:
+                tools = [item.function.model_dump() for item in request.tools]

-
-
-
-                request.tool_choice
-            )
+            tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
+            parser = FunctionCallParser(request.tools, tool_call_parser)
+            tool_call_constraint = parser.get_structure_constraint(request.tool_choice)

-
-
-
-                self._apply_jinja_template(request, tools, is_multimodal)
-            )
-        else:
-            prompt, prompt_ids, image_data, audio_data, modalities, stop = (
-                self._apply_conversation_template(request, is_multimodal)
-            )
+        # Use chat template
+        if self.template_manager.chat_template_name is None:
+            result = self._apply_jinja_template(request, tools, is_multimodal)
         else:
-
-
-
-
-            audio_data = None
-            modalities = []
-            prompt = request.messages
-
-        return (
-            prompt,
-            prompt_ids,
-            image_data,
-            audio_data,
-            modalities,
-            stop,
-            tool_call_constraint,
-        )
+            result = self._apply_conversation_template(request, is_multimodal)
+
+        result.tool_call_constraint = tool_call_constraint
+        return result

     def _apply_jinja_template(
         self,
         request: ChatCompletionRequest,
         tools: Optional[List[Dict]],
         is_multimodal: bool,
-    ) ->
+    ) -> MessageProcessingResult:
         """Apply Jinja chat template"""
         prompt = ""
         prompt_ids = []
@@ -253,13 +215,20 @@ class OpenAIServingChat(OpenAIServingBase):
         image_data = image_data if image_data else None
         audio_data = audio_data if audio_data else None
         modalities = modalities if modalities else []
-        return
+        return MessageProcessingResult(
+            prompt=prompt,
+            prompt_ids=prompt_ids,
+            image_data=image_data,
+            audio_data=audio_data,
+            modalities=modalities,
+            stop=stop,
+        )

     def _apply_conversation_template(
         self,
         request: ChatCompletionRequest,
         is_multimodal: bool,
-    ) ->
+    ) -> MessageProcessingResult:
         """Apply conversation template"""
         prompt = ""
         prompt_ids = []
@@ -304,7 +273,14 @@ class OpenAIServingChat(OpenAIServingBase):
         if not is_multimodal:
             prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)

-        return
+        return MessageProcessingResult(
+            prompt=prompt,
+            prompt_ids=prompt_ids,
+            image_data=image_data,
+            audio_data=audio_data,
+            modalities=modalities,
+            stop=stop,
+        )

     def _build_sampling_params(
         self,
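The hunks above replace the old seven-element return tuple with a single `MessageProcessingResult` object imported from `sglang.srt.entrypoints.openai.protocol`. Its definition is not part of this diff; the sketch below is only an illustration consistent with how the object is constructed and mutated in these hunks (the real class may be a pydantic model with different defaults).

```python
# Illustrative only: the real MessageProcessingResult lives in
# sglang/srt/entrypoints/openai/protocol.py and is not shown in this diff.
from dataclasses import dataclass, field
from typing import Any, List, Optional, Union


@dataclass
class MessageProcessingResult:
    prompt: str = ""
    prompt_ids: Union[str, List[int]] = field(default_factory=list)
    image_data: Optional[Any] = None
    audio_data: Optional[Any] = None
    modalities: List[str] = field(default_factory=list)
    stop: List[str] = field(default_factory=list)
    tool_call_constraint: Optional[Any] = None  # filled in later by _process_messages
```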
@@ -4,10 +4,8 @@ from typing import TYPE_CHECKING, List

 import torch.cuda

-from sglang.srt.
-
-)
-from sglang.srt.managers.expert_location import ExpertLocationMetadata
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ExpertLocationMetadata

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
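This hunk and the several that follow only rewrite import paths for the EPLB move listed in the file summary (the `sglang/srt/managers/expert_*` and `eplb_*` modules, plus `model_executor/expert_location_updater.py`, now live under `sglang/srt/eplb/`). Out-of-tree code that imported these symbols directly needs the same one-line change, for example:

```python
# Before (sglang 0.4.8)
from sglang.srt.managers.expert_location import ExpertLocationMetadata

# After (sglang 0.4.9)
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
```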
@@ -24,7 +24,7 @@ import einops
 import torch
 import torch.distributed

-from sglang.srt.
+from sglang.srt.eplb.expert_location import ExpertLocationMetadata
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import ServerArgs
@@ -61,6 +61,10 @@ class ExpertDistributionRecorder(ABC):
     def with_debug_name(self, debug_name):
         yield

+    @contextmanager
+    def disable_this_region(self):
+        yield
+
     @contextmanager
     def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
         yield
@@ -116,6 +120,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
         self._expert_location_metadata = expert_location_metadata

         self._recording = False
+        self._disable_all = False
         self._current_forward_pass_id = Withable()
         self._current_layer_idx = Withable()
         self._current_debug_name = Withable()
@@ -148,6 +153,16 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
         finally:
             self._on_forward_pass_end(forward_pass_id)

+    @contextmanager
+    def disable_this_region(self):
+        """Context manager to temporarily disable recording."""
+        previous_disable_all = self._disable_all
+        self._disable_all = True
+        try:
+            yield
+        finally:
+            self._disable_all = previous_disable_all
+
     def _on_forward_pass_start(self, forward_batch: ForwardBatch):
         if not self._recording:
             return
@@ -189,6 +204,8 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
         )

     def _on_hook(self, hook_name: str, **kwargs):
+        if self._disable_all:
+            return
         if not (self._recording or torch.cuda.is_current_stream_capturing()):
             return
         gatherer = self._single_pass_gatherers[
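The new `disable_this_region()` context manager added above makes every hook fired inside the block a no-op (`_on_hook` now returns early while `_disable_all` is set). A hedged usage sketch, assuming access to the global recorder; the call inside the block is a placeholder for whatever work should be excluded from the recording:

```python
from sglang.srt.eplb.expert_distribution import (
    get_global_expert_distribution_recorder,
)

recorder = get_global_expert_distribution_recorder()

# Hooks fired inside this block are ignored because _disable_all is set.
with recorder.disable_this_region():
    run_warmup_forward_pass()  # placeholder, not an sglang API
```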
@@ -23,7 +23,7 @@ import torch.distributed
 import torch.nn.functional as F

 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.
+from sglang.srt.eplb import eplb_algorithms
 from sglang.srt.model_loader import get_model_architecture
 from sglang.srt.server_args import ServerArgs

@@ -17,7 +17,7 @@ from typing import Literal, Optional

 import torch

-from sglang.srt.
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.managers.schedule_batch import global_server_args_dict


@@ -20,7 +20,7 @@ import torch
 import torch.distributed
 from torch.distributed import P2POp

-from sglang.srt.
+from sglang.srt.eplb.expert_location import (
     ExpertLocationMetadata,
     get_global_expert_location_metadata,
 )
@@ -30,6 +30,9 @@ from sglang.srt.utils import get_bool_env_var
 logger = logging.getLogger(__name__)


+_LOG_INPUT = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_INPUT")
+
+
 class ExpertLocationUpdater:
     def __init__(self):
         self._first_execution = True
@@ -175,6 +178,19 @@ def update_expert_weights_single_layer(
     assert isinstance(old_physical_to_logical_map, list)
     assert isinstance(new_physical_to_logical_map, list)

+    if _LOG_INPUT:
+        logger.info(
+            "update_expert_weights_single_layer "
+            f"{[x.shape for x in routed_experts_weights]=} "
+            f"{[x.shape for x in temp_buffers]=} "
+            f"{old_physical_to_logical_map=} "
+            f"{new_physical_to_logical_map=} "
+            f"{num_local_physical_experts=} "
+            f"{num_gpu_per_node=} "
+            f"{rank=} "
+            f"{world_size=} "
+        )
+
     output_logs = [] if debug else None

     num_physical_experts = len(old_physical_to_logical_map)
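The new `SGLANG_EXPERT_LOCATION_UPDATER_LOG_INPUT` flag is read through `get_bool_env_var`, so the extra input logging can be switched on from the environment before the server process starts, assuming the usual truthy spellings such as "1" or "true":

```python
import os

# Hypothetical: enable the new expert-location-updater input logging
# before sglang is imported/launched in this process.
os.environ["SGLANG_EXPERT_LOCATION_UPDATER_LOG_INPUT"] = "1"
```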
@@ -42,7 +42,7 @@ from sglang.srt.configs import (
 )
 from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
-from sglang.srt.utils import is_remote_url
+from sglang.srt.utils import is_remote_url, lru_cache_frozenset

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
@@ -103,6 +103,7 @@ def get_hf_text_config(config: PretrainedConfig):
     return config


+@lru_cache_frozenset(maxsize=32)
 def get_config(
     model: str,
     trust_remote_code: bool,
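`get_config` is now memoized with `lru_cache_frozenset`, a helper added to `sglang/srt/utils.py` (its definition is not shown in this diff). The name suggests an LRU cache that tolerates unhashable arguments such as dicts by freezing them into hashable keys; the following is only a rough, self-contained sketch of that idea, not the actual sglang implementation:

```python
# Illustrative sketch only; the real lru_cache_frozenset in sglang.srt.utils
# may differ in behavior and signature.
import collections
import functools


def lru_cache_frozenset(maxsize=128):
    def _freeze(value):
        if isinstance(value, dict):
            return frozenset((k, _freeze(v)) for k, v in value.items())
        if isinstance(value, (list, tuple, set)):
            return tuple(_freeze(v) for v in value)
        return value

    def decorator(func):
        cache = collections.OrderedDict()

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            key = (_freeze(args), _freeze(kwargs))
            if key in cache:
                cache.move_to_end(key)
                return cache[key]
            result = func(*args, **kwargs)
            cache[key] = result
            if len(cache) > maxsize:
                cache.popitem(last=False)  # evict least recently used entry
            return result

        return wrapper

    return decorator
```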
sglang/srt/layers/activation.py CHANGED
@@ -46,6 +46,9 @@ _is_cpu = is_cpu()
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

+if is_npu():
+    import torch_npu
+
 logger = logging.getLogger(__name__)


@@ -70,6 +73,10 @@ class SiluAndMul(CustomOp):
         else:
             return self.forward_native(x)

+    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch_npu.npu_swiglu(x)
+        return out
+
 
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
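`torch_npu.npu_swiglu` is used here as an Ascend NPU drop-in for the existing paths: the last dimension of the input is split in half and the result is `silu(x1) * x2`. A quick sanity check against the portable `forward_native` path (hypothetical snippet, assumes a working sglang install):

```python
import torch

from sglang.srt.layers.activation import SiluAndMul

x = torch.randn(4, 8)
out = SiluAndMul().forward_native(x)  # silu(x[..., :4]) * x[..., 4:]
print(out.shape)  # torch.Size([4, 4])
```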
@@ -0,0 +1,86 @@
+import logging
+
+import torch
+
+from sglang.srt.utils import cpu_has_amx_support
+
+logger = logging.getLogger(__name__)
+
+
+def amx_process_weight_after_loading(weight):
+    if weight.device != torch.device("cpu"):
+        return weight
+    if not cpu_has_amx_support():
+        return weight
+
+    return torch.ops.sgl_kernel.convert_weight_packed(weight)
+
+
+# TODO: currently gemm kernel has the below requirements:
+# OC % TILE_N == 0, where TILE_N = 16
+# IC % TILE_K == 0, where TILE_K = 32
+def dim_is_supported(weight):
+    TILE_N = 16
+    TILE_K = 32
+    ndim = weight.ndim
+    OC = weight.size(1) if ndim == 3 else weight.size(0)
+    IC = weight.size(2) if ndim == 3 else weight.size(1)
+    return OC % TILE_N == 0 and IC % TILE_K == 0
+
+
+def _amx_process_weight_after_loading(
+    module, weight_names, transpose_dims=None
+) -> None:
+    # Pack weight for get better performance on CPU
+    devices = {getattr(module, weight_name).device for weight_name in weight_names}
+    assert len(devices) == 1, f"Expects all weights to be on the same device"
+    device = devices.pop()
+
+    if transpose_dims:
+        assert len(weight_names) == len(
+            transpose_dims
+        ), "len(weight_names) should be equal to len(transpose_dims)"
+
+    for i, weight_name in enumerate(weight_names):
+        weight_tensor = getattr(module, weight_name)
+
+        if transpose_dims and transpose_dims[i]:
+            weight_tensor = weight_tensor.transpose(*transpose_dims[i])
+
+        # We don't pack weight or use intel amx backend if any weight of this module has unsupported dim.
+        if not dim_is_supported(weight_tensor):
+            logger.warning(
+                f"Unsupported dimension for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} in {module}. "
+                f"The derived (OC, IC) dimensions must be divisible by (16, 32). "
+            )
+            module.use_intel_amx_backend = False
+            return
+
+        packed_weight = torch.nn.Parameter(
+            amx_process_weight_after_loading(weight_tensor),
+            requires_grad=False,
+        )
+        packed_weight.__dict__ = weight_tensor.__dict__
+        setattr(module, weight_name, packed_weight)
+
+    module.use_intel_amx_backend = (
+        device == torch.device("cpu") and cpu_has_amx_support()
+    )
+
+    if (
+        module.use_intel_amx_backend
+        and hasattr(module, "bias")
+        and module.bias is not None
+    ):
+        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)
+
+
+class PackWeightMethod:
+    def __init__(self, weight_names, transpose_dims=None):
+        self.weight_names = weight_names
+        self.transpose_dims = transpose_dims
+
+    def process_weights_after_loading(self, module) -> None:
+        _amx_process_weight_after_loading(
+            module, self.weight_names, self.transpose_dims
+        )
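A hedged usage sketch for the new `sglang/srt/layers/amx_utils.py` helper shown above: `PackWeightMethod` repacks eligible CPU weights through `torch.ops.sgl_kernel.convert_weight_packed` and records whether the Intel AMX backend can be used. How sglang wires this into its layers is not part of this diff; the snippet below simply exercises the class directly on a plain `nn.Linear`.

```python
import torch

from sglang.srt.layers.amx_utils import PackWeightMethod

# weight shape is (out_features, in_features) = (32, 64): OC % 16 == 0, IC % 32 == 0
linear = torch.nn.Linear(64, 32, bias=False)

packer = PackWeightMethod(weight_names=["weight"])
packer.process_weights_after_loading(linear)

# True only on a CPU with AMX support and an sgl-kernel build providing
# convert_weight_packed; otherwise the weight is left unpacked.
print(linear.use_intel_amx_backend)
```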