sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -17,21 +17,36 @@

 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""

+import logging
+from dataclasses import dataclass
+from enum import Enum, auto
 from functools import partial
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
 import torch.nn.functional as F
 from torch import nn
+from transformers.configuration_utils import PretrainedConfig

 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
+    attn_tp_reduce_scatter,
+    dp_gather_partial,
+    dp_scatter,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    get_local_attention_dp_size,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -39,52 +54,69 @@ from sglang.srt.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import DeepEPMode, add_prefix

 Qwen3MoeConfig = None

+logger = logging.getLogger(__name__)
+

 class Qwen3MoeSparseMoeBlock(nn.Module):
     def __init__(
         self,
+        layer_id: int,
         config: Qwen3MoeConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
-
+        self.layer_id = layer_id
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {config.num_experts}."
             )

-
-
-
-            num_experts=config.num_experts,
+        self.experts = get_moe_impl_class()(
+            num_experts=config.num_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
+            layer_id=layer_id,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
+            **(
+                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                if global_server_args_dict["enable_deepep_moe"]
+                else {}
+            ),
         )

         self.gate = ReplicatedLinear(
@@ -95,7 +127,45 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             prefix=add_prefix("gate", prefix),
         )

-
+        if global_server_args_dict["enable_deepep_moe"]:
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_tensor_model_parallel_world_size()
+            self.num_experts = (
+                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            )
+            self.top_k = config.num_experts_per_tok
+            self.renormalize = config.norm_topk_prob
+
+            self.deepep_dispatcher = DeepEPDispatcher(
+                group=parallel_state.get_tp_group().device_group,
+                router_topk=self.top_k,
+                permute_fusion=True,
+                num_experts=self.num_experts,
+                num_local_experts=config.num_experts // self.tp_size,
+                hidden_size=config.hidden_size,
+                params_dtype=config.torch_dtype,
+                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+                async_finish=True,  # TODO
+                return_recv_hook=True,
+            )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+    ) -> torch.Tensor:
+
+        if not global_server_args_dict["enable_deepep_moe"]:
+            return self.forward_normal(hidden_states)
+        else:
+            return self.forward_deepep(hidden_states, forward_mode)
+
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

@@ -109,6 +179,71 @@ class Qwen3MoeSparseMoeBlock(nn.Module):

         return final_hidden_states.view(num_tokens, hidden_dim)

+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+    ) -> torch.Tensor:
+        if (
+            forward_mode is not None
+            and not forward_mode.is_idle()
+            and hidden_states.shape[0] > 0
+        ):
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+
+            topk_weights, topk_idx = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=False,
+                renormalize=self.renormalize,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            topk_idx = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            topk_weights = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+        if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                num_recv_tokens_per_expert,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            ) = self.deepep_dispatcher.dispatch(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode=forward_mode,
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            reorder_topk_ids=reorder_topk_ids,
+            seg_indptr=seg_indptr,
+            masked_m=masked_m,
+            expected_m=expected_m,
+            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
+            forward_mode=forward_mode,
+        )
+        if self.ep_size > 1:
+            final_hidden_states = self.deepep_dispatcher.combine(
+                final_hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode,
+            )
+        return final_hidden_states
+

 class Qwen3MoeAttention(nn.Module):
     def __init__(
@@ -128,20 +263,23 @@ class Qwen3MoeAttention(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-
+
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.total_num_heads = num_heads
-        assert self.total_num_heads %
-        self.num_heads = self.total_num_heads //
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >=
+        if self.total_num_kv_heads >= attn_tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads %
+            assert self.total_num_kv_heads % attn_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert
-        self.num_kv_heads = max(1, self.total_num_kv_heads //
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
         self.head_dim = head_dim or hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -157,6 +295,8 @@ class Qwen3MoeAttention(nn.Module):
             self.total_num_kv_heads,
             bias=attention_bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("qkv_proj", prefix),
         )

@@ -165,6 +305,9 @@ class Qwen3MoeAttention(nn.Module):
             hidden_size,
             bias=attention_bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            reduce_results=False,
             prefix=add_prefix("o_proj", prefix),
         )

@@ -213,6 +356,19 @@ class Qwen3MoeAttention(nn.Module):
         return output


+class _FFNInputMode(Enum):
+    # The MLP sublayer requires 1/tp_size tokens as input
+    SCATTERED = auto()
+    # The MLP sublayer requires all tokens as input
+    FULL = auto()
+
+
+@dataclass
+class _DecoderLayerInfo:
+    is_sparse: bool
+    ffn_input_mode: _FFNInputMode
+
+
 class Qwen3MoeDecoderLayer(nn.Module):
     def __init__(
         self,
@@ -246,15 +402,23 @@ class Qwen3MoeDecoderLayer(nn.Module):
             prefix=add_prefix("self_attn", prefix),
         )

-
-
-
-
+        self.layer_id = layer_id
+
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.local_dp_size = get_local_attention_dp_size()
+
+        self.info = self._compute_info(config, layer_id=layer_id)
+        previous_layer_info = self._compute_info(config, layer_id=layer_id - 1)
+        self.input_is_scattered = (
+            layer_id > 0
+            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
-
-
-
+        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
+
+        if self.info.is_sparse:
             self.mlp = Qwen3MoeSparseMoeBlock(
+                layer_id=self.layer_id,
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
@@ -272,28 +436,182 @@ class Qwen3MoeDecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )

+    @staticmethod
+    def _enable_moe_dense_fully_dp():
+        return global_server_args_dict["moe_dense_tp_size"] == 1
+
+    @staticmethod
+    def _compute_info(config: PretrainedConfig, layer_id: int):
+        # WARN: Qwen3MOE has no dense_layer, it is only for compatibility.
+        mlp_only_layers = (
+            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
+        )
+        is_sparse = (layer_id not in mlp_only_layers) and (
+            config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
+        )
+        ffn_input_mode = (
+            _FFNInputMode.SCATTERED
+            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+            or (Qwen3MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
+            else _FFNInputMode.FULL
+        )
+        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-
-
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
+            return self.forward_ffn_with_scattered_input(
+                positions, hidden_states, forward_batch, residual
+            )
+        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
+            return self.forward_ffn_with_full_input(
+                positions, hidden_states, forward_batch, residual
+            )
+        else:
+            raise NotImplementedError
+
+    def forward_ffn_with_full_input(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if hidden_states.shape[0] == 0:
             residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
         else:
-
-
-
-
-
-
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        # Self Attention
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        # Gather
+        if get_tensor_model_parallel_world_size() > 1:
+            if self.local_dp_size != 1:
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                hidden_states = self.post_attention_layernorm(hidden_states)
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                # TODO extract this bugfix
+                if hidden_states.shape[0] != 0:
+                    hidden_states, residual = self.post_attention_layernorm(
+                        hidden_states, residual
+                    )
+        elif hidden_states.shape[0] != 0:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )

         # Fully Connected
-        hidden_states
-
+        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+
+        # TODO: use reduce-scatter in MLP to avoid this scatter
+        # Scatter
+        if self.local_dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
+        return hidden_states, residual
+
+    def forward_ffn_with_scattered_input(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if hidden_states.shape[0] == 0:
+            residual = hidden_states
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        if self.attn_tp_size != 1 and self.input_is_scattered:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            attn_tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+
+        # Self Attention
+        if hidden_states.shape[0] != 0:
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+        if self.attn_tp_size != 1:
+            if self.input_is_scattered:
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                attn_tp_reduce_scatter(hidden_states, tensor_list)
+                if hidden_states.shape[0] != 0:
+                    hidden_states, residual = self.post_attention_layernorm(
+                        hidden_states, residual
+                    )
+            else:
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                attn_tp_reduce_scatter(hidden_states, tensor_list)
+                residual = hidden_states
+                if hidden_states.shape[0] != 0:
+                    hidden_states = self.post_attention_layernorm(hidden_states)
+        else:
+            if hidden_states.shape[0] != 0:
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+
+        if not (
+            self._enable_moe_dense_fully_dp()
+            and (not self.info.is_sparse)
+            and hidden_states.shape[0] == 0
+        ):
+            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+
+        if self.is_last_layer and self.attn_tp_size != 1:
+            hidden_states += residual
+            residual = None
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            attn_tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+
         return hidden_states, residual


@@ -313,7 +631,6 @@ class Qwen3MoeModel(Qwen2MoeModel):


 class Qwen3MoeForCausalLM(nn.Module):
-
     fall_back_to_pt_during_load = False

     def __init__(
@@ -323,6 +640,7 @@ class Qwen3MoeForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen3MoeModel(
@@ -343,12 +661,31 @@ class Qwen3MoeForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(
-
-
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )

+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -359,9 +696,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]

-
-
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
@@ -370,6 +705,17 @@ class Qwen3MoeForCausalLM(nn.Module):

         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -418,11 +764,28 @@ class Qwen3MoeForCausalLM(nn.Module):
                 if name not in params_dict:
                     continue

-
-
-
-
-
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
+
+        self.routed_experts_weights_of_layer = {
+            layer_id: layer.mlp.get_moe_weights()
+            for layer_id, layer in enumerate(self.model.layers)
+            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock)
+        }
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_experts,
+            num_groups=None,
+        )


 EntryClass = Qwen3MoeForCausalLM
sglang/srt/models/roberta.py
CHANGED
@@ -57,7 +57,7 @@ class RobertaEmbedding(nn.Module):
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)

-        #
+        # Adapted from vllm: https://github.com/vllm-project/vllm/commit/4a18fd14ba4a349291c798a16bf62fa8a9af0b6b/vllm/model_executor/models/roberta.py

         pos_list = []
         token_list = []