sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +11 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +4 -3
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +71 -0
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/vision.py +13 -5
- sglang/srt/layers/communicator.py +21 -4
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +2 -7
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +77 -73
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +3 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +55 -30
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +28 -7
- sglang/srt/managers/scheduler.py +26 -12
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +24 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +7 -6
- sglang/srt/model_executor/forward_batch_info.py +35 -14
- sglang/srt/model_executor/model_runner.py +19 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +72 -33
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +24 -12
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +142 -7
- sglang/srt/two_batch_overlap.py +157 -5
- sglang/srt/utils.py +38 -2
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/bailing_moe.py
ADDED
@@ -0,0 +1,425 @@
+# Copyright 2023-2024 SGLang Team
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/bailing_moe.py
+
+from collections.abc import Iterable
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class BailingAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+
+        self.total_num_heads = config.num_attention_heads
+        self.total_num_kv_heads = config.num_key_value_heads
+
+        assert self.total_num_heads % tp_size == 0
+        assert self.total_num_kv_heads % tp_size == 0
+
+        self.num_heads = self.total_num_heads // tp_size
+        self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
+        self.q_size = self.num_heads * self.head_dim
+
+        self.num_kv_heads = self.total_num_kv_heads // tp_size
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scale = self.head_dim**-0.5
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=(config.use_bias or config.use_qkv_bias),
+            quant_config=quant_config,
+            prefix=add_prefix("query_key_value", prefix),
+        )
+
+        self.dense = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("dense", prefix),
+        )
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scale,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=config.max_position_embeddings,
+            base=config.rope_theta,
+            is_neox_style=True,
+            rope_scaling=config.rope_scaling,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.query_key_value(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        q, k = self.rotary_emb(position_ids, q, k)
+        context_layer = self.attn(q, k, v, forward_batch)
+        attn_output, _ = self.dense(context_layer)
+        return attn_output
+
+
+class BailingMLP(nn.Module):
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: Optional[bool] = True,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            config.hidden_size,
+            [intermediate_size] * 2,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            config.hidden_size,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+            prefix=add_prefix("down_proj", prefix),
+        )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        x, _ = self.gate_up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class BailingMoE(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.hidden_size = config.hidden_size
+        self.num_shared_experts = config.num_shared_experts
+        self.norm_expert_prob = config.norm_topk_prob
+        self.moe_intermediate_size = config.moe_intermediate_size
+
+        self.gate = ReplicatedLinear(
+            self.hidden_size, self.num_experts, bias=False, quant_config=None
+        )
+
+        self.topk = TopK(top_k=self.top_k, renormalize=self.norm_expert_prob)
+
+        self.experts = FusedMoE(
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            layer_id=layer_id,
+            hidden_size=self.hidden_size,
+            intermediate_size=self.moe_intermediate_size,
+            reduce_results=False,
+            quant_config=quant_config,
+            prefix=add_prefix("experts", prefix),
+        )
+
+        if self.num_shared_experts > 0:
+            shared_intermediate_size = (
+                self.moe_intermediate_size * self.num_shared_experts
+            )
+            self.shared_experts = BailingMLP(
+                intermediate_size=shared_intermediate_size,
+                config=config,
+                quant_config=quant_config,
+                reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
+            )
+        else:
+            self.shared_experts = None
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        orig_shape = hidden_states.shape
+        hidden_states_flat = hidden_states.view(-1, self.hidden_size)
+
+        shared_output = None
+        if self.shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states_flat)
+
+        router_logits, _ = self.gate(hidden_states_flat)
+        topk_output = self.topk(hidden_states_flat, router_logits)
+        final_hidden_states = self.experts(hidden_states_flat, topk_output)
+
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states.view(orig_shape)
+
+
+class BailingMoeBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.attention = BailingAttention(
+            config, layer_id, quant_config, prefix=add_prefix("attention", prefix)
+        )
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.mlp = BailingMoE(
+            config=config,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Pre-normalization and residual connection for the attention block
+        if residual is None:
+            residual = hidden_states
+            normed_hidden_states = self.input_layernorm(hidden_states)
+        else:
+            normed_hidden_states, residual = self.input_layernorm(
+                hidden_states, residual
+            )
+
+        attn_output = self.attention(
+            hidden_states=normed_hidden_states,
+            position_ids=position_ids,
+            forward_batch=forward_batch,
+        )
+
+        # Pre-normalization and residual connection for the MLP block
+        normed_hidden_states, residual = self.post_attention_layernorm(
+            attn_output, residual
+        )
+        mlp_output = self.mlp(normed_hidden_states)
+
+        return mlp_output, residual
+
+
+class BailingMoeModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.embed_dim = config.hidden_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
+        )
+        self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout)
+
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: BailingMoeBlock(
+                config=config,
+                layer_id=idx,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            prefix=add_prefix("layers", prefix),
+        )
+
+        self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                hidden_states,
+                position_ids,
+                residual,
+                forward_batch,
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class BailingMoeForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.model = BailingMoeModel(config=config, quant_config=quant_config)
+        self.lm_head = ParallelLMHead(
+            num_embeddings=config.vocab_size,
+            embedding_dim=config.hidden_size,
+            quant_config=quant_config,
+        )
+        if config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
+        stacked_params_mapping = [
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+
+            if (
+                hasattr(self.config, "norm_head")
+                and self.config.norm_head
+                and "lm_head.weight" in name
+            ):
+                loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7)
+
+            if "model.word_embeddings.weight" == name:
+                name = "model.embed_tokens.weight"
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name in name and "mlp.experts" not in name:
+                    full_param_name = name.replace(weight_name, param_name)
+                    param = params_dict[full_param_name]
+                    param.weight_loader(param, loaded_weight, shard_id)
+                    break
+            else:
+                for p_name, w_name, e_id, s_id in expert_params_mapping:
+                    if w_name in name and "mlp.experts" in name:
+                        full_param_name = name.replace(w_name, p_name)
+                        param = params_dict[full_param_name]
+                        param.weight_loader(
+                            param,
+                            loaded_weight,
+                            full_param_name,
+                            shard_id=s_id,
+                            expert_id=e_id,
+                        )
+                        break
+                else:
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = BailingMoeForCausalLM
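The new BailingMoE layer composes sglang's TopK and FusedMoE primitives: a replicated gate produces router logits, TopK picks num_experts_per_tok experts per token (renormalizing when norm_topk_prob is set), FusedMoE computes the weighted expert outputs with reduce_results=False, the shared-expert MLP output is added locally, and a single tensor-parallel all-reduce finishes the layer. The sketch below restates that routing math in plain PyTorch so the data flow is easy to check; the reference_moe name, the dense per-expert loop, and the toy shapes are illustrative assumptions, not sglang's fused Triton path.

# Minimal sketch (assumption: plain PyTorch, no sglang kernels) of the routing that
# BailingMoE delegates to TopK + FusedMoE: softmax gate -> top-k selection ->
# optional renormalization -> weighted sum of expert MLPs -> add shared-expert output.
import torch
import torch.nn.functional as F


def reference_moe(
    hidden: torch.Tensor,      # [num_tokens, hidden_size]
    gate_w: torch.Tensor,      # [num_experts, hidden_size], the replicated gate weight
    experts: list,             # per-expert callables, hidden -> hidden
    shared_expert=None,        # optional callable, hidden -> hidden
    top_k: int = 2,
    renormalize: bool = True,  # mirrors config.norm_topk_prob
) -> torch.Tensor:
    router_logits = hidden @ gate_w.t()              # [num_tokens, num_experts]
    probs = F.softmax(router_logits, dim=-1)
    topk_w, topk_ids = probs.topk(top_k, dim=-1)     # [num_tokens, top_k]
    if renormalize:
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(hidden)
    for e, expert in enumerate(experts):
        # Gather the tokens routed to expert e and accumulate their weighted outputs.
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        out[token_idx] += topk_w[token_idx, slot].unsqueeze(-1) * expert(hidden[token_idx])

    if shared_expert is not None:
        out = out + shared_expert(hidden)            # shared experts see every token
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    h, e = 16, 4
    x = torch.randn(8, h)
    gate_w = torch.randn(e, h)
    experts = [
        torch.nn.Sequential(torch.nn.Linear(h, 32), torch.nn.SiLU(), torch.nn.Linear(32, h))
        for _ in range(e)
    ]
    shared = torch.nn.Linear(h, h)
    print(reference_moe(x, gate_w, experts, shared, top_k=2).shape)  # torch.Size([8, 16])

Passing reduce_results=False to both FusedMoE and the shared-expert BailingMLP is what lets the layer add the shared-expert contribution before reducing, so only one all-reduce is issued per MoE block when tp_size > 1.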
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -60,12 +60,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import (
-    DeepEPMoE,
-    get_moe_impl_class,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -211,13 +208,21 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
+    ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x

         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(
+        x, _ = self.down_proj(
+            x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter
+        )
         return x


@@ -307,19 +312,15 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )

-        self.topk = (
-
-
-
-
-
-
-
-                routed_scaling_factor=self.routed_scaling_factor,
-            )
-            if not should_use_flashinfer_trtllm_moe()
-            else None
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
         )

         self.experts = get_moe_impl_class()(
@@ -448,6 +449,7 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
         can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -457,15 +459,20 @@ class DeepseekV2MoE(nn.Module):
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
                 return self.forward_normal_dual_stream(
-                    hidden_states, can_fuse_mlp_allreduce
+                    hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
                 )
             else:
-                return self.forward_normal(
+                return self.forward_normal(
+                    hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
+                )
         else:
             return self.forward_deepep(hidden_states, forward_batch)

     def forward_normal_dual_stream(
-        self,
+        self,
+        hidden_states: torch.Tensor,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:

         current_stream = torch.cuda.current_stream()
@@ -476,10 +483,14 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-
-
+
+        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+        if should_use_flashinfer_trtllm_moe():
+            kwargs["topk_output"] = (self.topk, router_logits)
         else:
-            kwargs["
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda:
             final_hidden_states *= self.routed_scaling_factor
@@ -489,12 +500,15 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

     def forward_normal(
-        self,
+        self,
+        hidden_states: torch.Tensor,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -505,10 +519,14 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-
-
+
+        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+        if should_use_flashinfer_trtllm_moe():
+            kwargs["topk_output"] = (self.topk, router_logits)
         else:
-            kwargs["
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
@@ -519,7 +537,7 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

@@ -1821,6 +1839,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )

     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
@@ -1883,7 +1902,13 @@ class DeepseekV2DecoderLayer(nn.Module):
             and not self.is_nextn
         )

-
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter
+        )

         if can_fuse_mlp_allreduce:
             hidden_states._sglang_needs_allreduce_fusion = True
@@ -2060,6 +2085,8 @@ class DeepseekV2Model(nn.Module):


 class DeepseekV2ForCausalLM(nn.Module):
+    # for quark model load
+    packed_modules_mapping = {}

     def __init__(
         self,
@@ -2068,6 +2095,18 @@ class DeepseekV2ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        # for quark model load
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        self.fuse_qkv_a_proj = (
+            hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
+        )
+        if self.fuse_qkv_a_proj:
+            self.packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
+                "q_a_proj",
+                "kv_a_proj_with_mqa",
+            ]
+
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
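Several of these deepseek_v2.py hunks thread a new use_reduce_scatter flag from DeepseekV2DecoderLayer (via layer_communicator.should_use_reduce_scatter and allow_reduce_scatter=True) down into DeepseekV2MoE and DeepseekV2MLP, where it suppresses the final all-reduce the same way can_fuse_mlp_allreduce does (skip_all_reduce on down_proj, and the extended tp_size > 1 condition). The idea, per the diff's comment, is that with DP-attention padding each rank only needs its own token shard after the MLP, so a reduce-scatter can replace all-reduce plus local slicing. The single-process sketch below checks that equivalence with simulated ranks; the tensor shapes and variable names are assumptions for illustration, not sglang or torch.distributed APIs.

# Single-process sketch (assumption: ranks are simulated as list entries, no real
# collectives) of why use_reduce_scatter can replace the MoE all-reduce: summing the
# per-rank partial outputs and then keeping only a rank's token shard (reduce-scatter)
# yields the same rows as an all-reduce followed by local slicing.
import torch

world_size, tokens_per_rank, hidden = 4, 3, 8
total_tokens = world_size * tokens_per_rank

# Partial MLP outputs that each tensor-parallel rank would contribute before reduction.
partials = [torch.randn(total_tokens, hidden) for _ in range(world_size)]

# Path 1: all-reduce (sum over ranks), then every rank slices out its own token shard.
all_reduced = torch.stack(partials).sum(dim=0)
shard_via_allreduce = [
    all_reduced[r * tokens_per_rank : (r + 1) * tokens_per_rank] for r in range(world_size)
]

# Path 2: reduce-scatter -- same sum over ranks, but each rank only receives its shard.
shard_via_reduce_scatter = [
    torch.stack([p[r * tokens_per_rank : (r + 1) * tokens_per_rank] for p in partials]).sum(dim=0)
    for r in range(world_size)
]

for a, b in zip(shard_via_allreduce, shard_via_reduce_scatter):
    assert torch.allclose(a, b)
print("reduce-scatter shard == all-reduce + local slice")

Since a ring all-reduce is effectively a reduce-scatter followed by an all-gather, dropping the gather step when no rank needs the full padded tensor saves roughly half of the communication for this reduction.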