sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registries. It is provided for informational purposes only and reflects the changes between those published versions.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -16,7 +16,10 @@
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
 
-
+import logging
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -24,10 +27,20 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
+    attn_tp_reduce_scatter,
+    dp_gather_partial,
+    dp_scatter,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    get_local_attention_dp_size,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -35,23 +48,28 @@ from sglang.srt.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers
 
-
+logger = logging.getLogger(__name__)
 
 
 class Qwen2MoeMLP(nn.Module):
@@ -82,8 +100,7 @@ class Qwen2MoeMLP(nn.Module):
         )
         if hidden_act != "silu":
             raise ValueError(
-                f"Unsupported activation: {hidden_act}. "
-                "Only silu is supported for now."
+                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
             )
         self.act_fn = SiluAndMul()
 
@@ -160,7 +177,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         )
         if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
        return final_hidden_states.view(num_tokens, hidden_dim)
@@ -182,20 +198,23 @@ class Qwen2MoeAttention(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-
+
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.total_num_heads = num_heads
-        assert self.total_num_heads %
-        self.num_heads = self.total_num_heads //
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >=
+        if self.total_num_kv_heads >= attn_tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads %
+            assert self.total_num_kv_heads % attn_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert
-        self.num_kv_heads = max(1, self.total_num_kv_heads //
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
         self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -210,6 +229,8 @@ class Qwen2MoeAttention(nn.Module):
             self.total_num_kv_heads,
             bias=qkv_bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("qkv_proj", prefix),
         )
 
@@ -218,6 +239,9 @@ class Qwen2MoeAttention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            reduce_results=False,
             prefix=add_prefix("o_proj", prefix),
         )
 
@@ -252,6 +276,19 @@ class Qwen2MoeAttention(nn.Module):
         return output
 
 
+class _FFNInputMode(Enum):
+    # The MLP sublayer requires 1/tp_size tokens as input
+    SCATTERED = auto()
+    # The MLP sublayer requires all tokens as input
+    FULL = auto()
+
+
+@dataclass
+class _DecoderLayerInfo:
+    is_sparse: bool
+    ffn_input_mode: _FFNInputMode
+
+
 class Qwen2MoeDecoderLayer(nn.Module):
     def __init__(
         self,
@@ -279,14 +316,21 @@ class Qwen2MoeDecoderLayer(nn.Module):
             prefix=add_prefix("self_attn", prefix),
         )
 
-
-
-
-
+        self.layer_id = layer_id
+
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.local_dp_size = get_local_attention_dp_size()
+
+        self.info = self._compute_info(config, layer_id=layer_id)
+        previous_layer_info = self._compute_info(config, layer_id=layer_id - 1)
+        self.input_is_scattered = (
+            layer_id > 0
+            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
-
-
-
+        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
+
+        if self.info.is_sparse:
             self.mlp = Qwen2MoeSparseMoeBlock(
                 config=config,
                 quant_config=quant_config,
@@ -305,28 +349,185 @@ class Qwen2MoeDecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+    @staticmethod
+    def _enable_moe_dense_fully_dp():
+        return global_server_args_dict["moe_dense_tp_size"] == 1
+
+    @staticmethod
+    def _compute_info(config: PretrainedConfig, layer_id: int):
+        # WARN: Qwen2MOE has no dense_layer, it is only for compatibility.
+        mlp_only_layers = (
+            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
+        )
+        is_sparse = (layer_id not in mlp_only_layers) and (
+            config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
+        )
+        ffn_input_mode = (
+            _FFNInputMode.SCATTERED
+            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+            or (Qwen2MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
+            else _FFNInputMode.FULL
+        )
+        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-
-
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
+            return self.forward_ffn_with_scattered_input(
+                positions, hidden_states, forward_batch, residual
+            )
+        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
+            return self.forward_ffn_with_full_input(
+                positions, hidden_states, forward_batch, residual
+            )
+        else:
+            raise NotImplementedError
+
+    def forward_ffn_with_full_input(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if hidden_states.shape[0] == 0:
            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
         else:
-
-
-
-
-
-
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        # Self Attention
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        # Gather
+        if get_tensor_model_parallel_world_size() > 1:
+            # all gather and all reduce
+            if self.local_dp_size != 1:
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                # TODO extract this bugfix
+                if hidden_states.shape[0] != 0:
+                    hidden_states = self.post_attention_layernorm(hidden_states)
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                # TODO extract this bugfix
+                if hidden_states.shape[0] != 0:
+                    hidden_states, residual = self.post_attention_layernorm(
+                        hidden_states, residual
+                    )
+        elif hidden_states.shape[0] != 0:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
+
+        # TODO: use reduce-scatter in MLP to avoid this scatter
+        # Scatter
+        if self.local_dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
+        return hidden_states, residual
+
+    def forward_ffn_with_scattered_input(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if hidden_states.shape[0] == 0:
+            residual = hidden_states
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        if self.attn_tp_size != 1 and self.input_is_scattered:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            attn_tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+
+        # Self Attention
+        if hidden_states.shape[0] != 0:
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+        if self.attn_tp_size != 1:
+            if self.input_is_scattered:
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                attn_tp_reduce_scatter(hidden_states, tensor_list)
+                if hidden_states.shape[0] != 0:
+                    hidden_states, residual = self.post_attention_layernorm(
+                        hidden_states, residual
+                    )
+            else:
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                attn_tp_reduce_scatter(hidden_states, tensor_list)
+                residual = hidden_states
+                if hidden_states.shape[0] != 0:
+                    hidden_states = self.post_attention_layernorm(hidden_states)
+        else:
+            if hidden_states.shape[0] != 0:
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+
+        if not (
+            self._enable_moe_dense_fully_dp()
+            and (not self.info.is_sparse)
+            and hidden_states.shape[0] == 0
+        ):
+            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+
+        if self.is_last_layer and self.attn_tp_size != 1:
+            hidden_states += residual
+            residual = None
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            attn_tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
 
         return hidden_states, residual
 
 
@@ -341,15 +542,21 @@ class Qwen2MoeModel(nn.Module):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
         # Use the provided decoder layer type or default to Qwen2MoeDecoderLayer
         decoder_layer_type = decoder_layer_type or Qwen2MoeDecoderLayer
-        self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: decoder_layer_type(
                layer_id=idx,
@@ -357,9 +564,14 @@ class Qwen2MoeModel(nn.Module):
                quant_config=quant_config,
                prefix=prefix,
            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
            prefix=add_prefix("layers", prefix),
        )
-        self.
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
     def forward(
         self,
@@ -367,24 +579,42 @@ class Qwen2MoeModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-
-
-
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-
-
-
-
-
-
-
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual
+                )
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
            )
-
+        else:
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
 
 
 class Qwen2MoeForCausalLM(nn.Module):
-
     fall_back_to_pt_during_load = False
 
     def __init__(
@@ -394,6 +624,7 @@ class Qwen2MoeForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2MoeModel(
@@ -414,11 +645,29 @@ class Qwen2MoeForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(
-
-
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
        )
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -441,6 +690,16 @@ class Qwen2MoeForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -489,11 +748,22 @@ class Qwen2MoeForCausalLM(nn.Module):
                 if name not in params_dict:
                     continue
 
-
-
-
-
-
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_experts,
+            num_groups=None,
+        )
 
 
 EntryClass = Qwen2MoeForCausalLM
sglang/srt/models/qwen2_vl.py
CHANGED
@@ -486,6 +486,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        if any(item.precomputed_features is not None for item in items):
+            if not all(item.precomputed_features is not None for item in items):
+                raise NotImplementedError(
+                    "MM inputs where only some items are precomputed."
+                )
+            return torch.concat([item.precomputed_features for item in items])
         # in qwen-vl, last dim is the same
         pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
             self.visual.dtype
sglang/srt/models/qwen3.py
CHANGED
@@ -1,5 +1,6 @@
 # Adapted from qwen2.py
 
+import logging
 from functools import partial
 from typing import Any, Dict, Iterable, Optional, Tuple
 
@@ -7,6 +8,7 @@ import torch
 from torch import nn
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     split_tensor_along_last_dim,
@@ -19,8 +21,9 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
@@ -28,6 +31,8 @@ from sglang.srt.utils import add_prefix
 
 Qwen3Config = None
 
+logger = logging.getLogger(__name__)
+
 
 class Qwen3Attention(nn.Module):
     def __init__(
@@ -238,6 +243,7 @@ class Qwen3ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen3Model(
@@ -266,14 +272,33 @@ class Qwen3ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(
-
-
-
-
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
         else:
-            return
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -287,6 +312,17 @@ class Qwen3ForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -313,9 +349,15 @@ class Qwen3ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-
-
-
+
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight