sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +72 -10
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/deepseekvl2.py +10 -1
- sglang/srt/configs/model_config.py +6 -16
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/parallel_state.py +32 -5
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/entrypoints/http_server.py +7 -1
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/attention/flashattention_backend.py +582 -125
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/dp_attention.py +12 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +79 -6
- sglang/srt/layers/quantization/__init__.py +137 -165
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_kernel.py +2 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/gptq.py +30 -40
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +19 -33
- sglang/srt/lora/lora_manager.py +20 -7
- sglang/srt/lora/mem_pool.py +12 -6
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +6 -0
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/io_struct.py +4 -2
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +44 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -127
- sglang/srt/managers/scheduler.py +29 -23
- sglang/srt/managers/tokenizer_manager.py +1 -2
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +16 -13
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +64 -59
- sglang/srt/model_loader/loader.py +19 -1
- sglang/srt/model_loader/weight_utils.py +6 -3
- sglang/srt/models/clip.py +568 -0
- sglang/srt/models/deepseek_janus_pro.py +12 -17
- sglang/srt/models/deepseek_v2.py +339 -123
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_causal.py +12 -2
- sglang/srt/models/gemma3_mm.py +20 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +106 -93
- sglang/srt/openai_api/protocol.py +10 -5
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +120 -25
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +94 -25
- sglang/srt/utils.py +137 -51
- sglang/test/runners.py +27 -2
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +14 -27
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
```diff
@@ -16,12 +16,14 @@
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""

+import logging
 import os
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
 import torch.nn.functional as F
 from torch import nn
+from tqdm import tqdm
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
@@ -30,15 +32,14 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
-    decode_attention_fwd_grouped_rope,
-)
 from sglang.srt.layers.dp_attention import (
     dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    tp_all_gather,
+    tp_reduce_scatter,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -71,7 +72,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda, is_hip
+from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -81,8 +82,15 @@ if _is_cuda:
 else:
     from vllm import _custom_ops as ops

+if _is_hip:
+    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+        decode_attention_fwd_grouped_rope,
+    )
+
 expert_distribution_recorder = ExpertDistributionRecorder()

+logger = logging.getLogger(__name__)
+

 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -164,6 +172,12 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
+        self.n_share_experts_fusion = (
+            global_server_args_dict["n_share_experts_fusion"]
+            if global_server_args_dict["n_share_experts_fusion"] is not None
+            else 0
+        )
+
         self.routed_scaling_factor = config.routed_scaling_factor
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -184,9 +198,10 @@ class DeepseekV2MoE(nn.Module):
             if global_server_args_dict["enable_deepep_moe"]
             else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
         )
+
         self.experts = MoEImpl(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
+            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
+            top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             renormalize=config.norm_topk_prob,
@@ -196,9 +211,14 @@ class DeepseekV2MoE(nn.Module):
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
             prefix=add_prefix("experts", prefix),
+            **(
+                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                if global_server_args_dict["enable_deepep_moe"]
+                else {}
+            ),
         )

-        if config.n_shared_experts is not None:
+        if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
             if not global_server_args_dict["enable_deepep_moe"]:
```
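The two hunks above implement the new shared-experts fusion: the shared expert is folded into the routed-expert table, so the expert count grows by `n_share_experts_fusion` and the router selects one extra slot. A minimal sketch of that sizing arithmetic; the `256`/`8` config values below are illustrative DeepSeek-V3-style numbers, not taken from this diff:

```python
# Sketch of the fused-expert sizing used in the DeepseekV2MoE constructor above.
# The concrete config values passed in the example calls are assumptions.
def fused_moe_shape(n_routed_experts, num_experts_per_tok, n_share_experts_fusion):
    # Shared-expert replicas are appended after the routed experts...
    num_experts = n_routed_experts + n_share_experts_fusion
    # ...and top-k grows by exactly one slot whenever fusion is enabled.
    top_k = num_experts_per_tok + min(n_share_experts_fusion, 1)
    return num_experts, top_k

print(fused_moe_shape(256, 8, 8))  # (264, 9)  -- fusion with 8 replicas
print(fused_moe_shape(256, 8, 0))  # (256, 8)  -- fusion disabled
```

If the assumed 256-expert base is right, this is also why the file list above adds fused MoE kernel configs for `E=257` and `E=264`: 256 routed experts plus 1 or 8 fused shared-expert replicas.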
```diff
@@ -223,6 +243,8 @@ class DeepseekV2MoE(nn.Module):
         )

         if global_server_args_dict["enable_deepep_moe"]:
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_tensor_model_parallel_world_size()
             self.num_experts = config.n_routed_experts
             self.top_k = config.num_experts_per_tok
             self.renormalize = config.norm_topk_prob
@@ -242,7 +264,9 @@ class DeepseekV2MoE(nn.Module):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
+                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
                 async_finish=True,  # TODO
+                return_recv_hook=True,
             )

     def forward(
@@ -254,8 +278,10 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_deepep(hidden_states, forward_mode)

     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if self.n_shared_experts is not None:
+        if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
             shared_output = self.shared_experts(hidden_states)
+        else:
+            shared_output = None
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         final_hidden_states = (
@@ -278,7 +304,11 @@ class DeepseekV2MoE(nn.Module):
         topk_weights = torch.empty(
             (0, self.top_k), dtype=torch.float32, device=hidden_states.device
         )
-        if
+        if (
+            forward_mode is not None
+            and not forward_mode.is_idle()
+            and hidden_states.shape[0] > 0
+        ):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
             if self.n_shared_experts is not None:
@@ -293,28 +323,39 @@ class DeepseekV2MoE(nn.Module):
                 num_expert_group=self.num_expert_group,
                 correction_bias=self.correction_bias,
             )
-        if self.
-
-
-
-
-
-
-
-
+        if self.ep_size > 1:
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            ) = self.deepep_dispatcher.dispatch(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                self.num_experts,
+                forward_mode=forward_mode,
             )
         final_hidden_states = (
             self.experts(
-                hidden_states=
+                hidden_states=hidden_states,
                 reorder_topk_ids=reorder_topk_ids,
                 seg_indptr=seg_indptr,
+                masked_m=masked_m,
+                expected_m=expected_m,
                 forward_mode=forward_mode,
             )
             * self.routed_scaling_factor
         )
-        if self.
+        if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
+                final_hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode,
             )
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
@@ -645,14 +686,14 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_vc = None
         self.w_scale = None

-        self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
+        self.attention_backend = global_server_args_dict["attention_backend"]
         self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

     def no_absorb(self, forward_batch: ForwardBatch) -> bool:
-        if self.
+        if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             return (
                 not self.flashinfer_mla_disable_ragged
@@ -661,6 +702,9 @@ class DeepseekV2AttentionMLA(nn.Module):
                 and not forward_batch.forward_mode.is_draft_extend()
                 and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )
+        elif self.attention_backend == "fa3":
+            # Flash Attention: Keep absorbing for all extend/decode
+            return False
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             return (
```
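The two attention hunks above replace the dedicated `enable_flashinfer_mla` flag with the generic `attention_backend` string and add an `fa3` branch. A rough sketch of the resulting decision, with the per-batch ragged-prefill and Triton conditions collapsed into booleans for brevity (this mirrors the control flow shown in the diff, not the full sglang logic):

```python
# Simplified view of the no_absorb() decision after this change. The boolean
# arguments stand in for the detailed per-batch checks visible in the diff.
def no_absorb(attention_backend: str, ragged_prefill: bool, triton_prefill: bool) -> bool:
    if attention_backend == "flashinfer":
        # FlashInfer MLA: skip weight absorption when ragged prefill is used.
        return ragged_prefill
    elif attention_backend == "fa3":
        # FlashAttention 3 backend (new in this release): always absorb.
        return False
    else:
        # Triton: normal computation for prefill, weight absorption for extend/decode.
        return triton_prefill
```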
```diff
@@ -969,6 +1013,14 @@ class DeepseekV2DecoderLayer(nn.Module):
         is_nextn: bool = False,
         prefix: str = "",
     ) -> None:
+
+        def is_sparse_layer(l: int):
+            return (
+                config.n_routed_experts is not None
+                and l >= config.first_k_dense_replace
+                and l % config.moe_layer_freq == 0
+            )
+
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
```
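The `is_sparse_layer` helper above decides which decoder layers get a MoE block instead of a dense MLP. A worked example with assumed DeepSeek-V3-style values (`first_k_dense_replace=3`, `moe_layer_freq=1`, 61 layers); these numbers are illustrative, not read from this diff:

```python
# Stand-alone version of the is_sparse_layer() predicate above, with assumed config values.
def is_sparse_layer(l, n_routed_experts=256, first_k_dense_replace=3, moe_layer_freq=1):
    return (
        n_routed_experts is not None
        and l >= first_k_dense_replace
        and l % moe_layer_freq == 0
    )

sparse = [l for l in range(61) if is_sparse_layer(l)]
print(sparse[:4])  # [3, 4, 5, 6] -> the first three layers stay dense, the rest are MoE
```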
```diff
@@ -977,6 +1029,8 @@ class DeepseekV2DecoderLayer(nn.Module):
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
         self.dp_size = get_attention_dp_size()
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()

         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
@@ -1019,16 +1073,13 @@ class DeepseekV2DecoderLayer(nn.Module):
                 prefix=add_prefix("self_attn", prefix),
             )

-        if is_nextn or (
-            config.n_routed_experts is not None
-            and layer_id >= config.first_k_dense_replace
-            and layer_id % config.moe_layer_freq == 0
-        ):
+        if is_nextn or is_sparse_layer(layer_id):
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
+            self.is_sparse = True
         else:
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
@@ -1037,6 +1088,14 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
+            self.is_sparse = False
+
+        self.input_is_scattered = (
+            is_sparse_layer(layer_id - 1)
+            and global_server_args_dict["enable_deepep_moe"]
+        )
+        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
+
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -1049,6 +1108,23 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
+            return self.forward_deepep(
+                positions, hidden_states, forward_batch, residual
+            )
+        else:
+            return self.forward_normal(
+                positions, hidden_states, forward_batch, residual
+            )
+
+    def forward_normal(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+
         if hidden_states.shape[0] == 0:
             residual = hidden_states
         else:
@@ -1058,6 +1134,10 @@ class DeepseekV2DecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)

+        assert not (
+            self.attn_tp_size != 1 and self.input_is_scattered
+        ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
+
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
@@ -1069,25 +1149,15 @@ class DeepseekV2DecoderLayer(nn.Module):
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
             if self.dp_size != 1:
-                if
-
-
-
-
-
-
-
-
-                else:
-                    if get_attention_tp_rank() == 0:
-                        hidden_states += residual
-                    hidden_states, local_hidden_states = (
-                        forward_batch.gathered_buffer,
-                        hidden_states,
-                    )
-                    dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                    dp_scatter(residual, hidden_states, forward_batch)
-                    hidden_states = self.post_attention_layernorm(hidden_states)
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                hidden_states = self.post_attention_layernorm(hidden_states)
             else:
                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                 hidden_states, residual = self.post_attention_layernorm(
@@ -1101,6 +1171,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             # Fully Connected
             hidden_states = self.mlp(hidden_states)

+        # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
         # Scatter
         if self.dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
@@ -1113,9 +1184,79 @@ class DeepseekV2DecoderLayer(nn.Module):

         return hidden_states, residual

+    def forward_deepep(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:

-
+        if hidden_states.shape[0] == 0:
+            residual = hidden_states
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)

+        if self.attn_tp_size != 1 and self.input_is_scattered:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+
+        # Self Attention
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        if self.attn_tp_size != 1:
+            if self.input_is_scattered:
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                tp_reduce_scatter(hidden_states, tensor_list)
+                if hidden_states.shape[0] != 0:
+                    hidden_states, residual = self.post_attention_layernorm(
+                        hidden_states, residual
+                    )
+            else:
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                tp_reduce_scatter(hidden_states, tensor_list)
+                residual = hidden_states
+                if hidden_states.shape[0] != 0:
+                    hidden_states = self.post_attention_layernorm(hidden_states)
+        else:
+            if hidden_states.shape[0] != 0:
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+
+        if self.is_last_layer and self.attn_tp_size != 1:
+            hidden_states += residual
+            residual = None
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+
+        return hidden_states, residual
+
+
+class DeepseekV2Model(nn.Module):
     fall_back_to_pt_during_load = False

     def __init__(
```
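The `forward_deepep` path above keeps tokens scattered across attention-TP ranks between sparse layers: it all-gathers the shards into `forward_batch.gathered_buffer` before self-attention and reduce-scatters the attention output afterwards, via the new `tp_all_gather` / `tp_reduce_scatter` helpers from `sglang.srt.layers.dp_attention`. A rough stand-alone sketch of the same collective pattern using plain `torch.distributed` primitives; the helper semantics are assumed to mirror these collectives, and this is not the sglang implementation:

```python
import torch
import torch.distributed as dist


def gather_scattered_tokens(local_shard: torch.Tensor, attn_tp_size: int) -> torch.Tensor:
    # Each attention-TP rank holds 1/attn_tp_size of the tokens; all-gather
    # reassembles the full batch before self-attention. Assumes the token count
    # divides evenly across ranks (the diff splits a preallocated buffer instead).
    full = torch.empty(
        (local_shard.shape[0] * attn_tp_size, *local_shard.shape[1:]),
        dtype=local_shard.dtype,
        device=local_shard.device,
    )
    dist.all_gather(list(full.tensor_split(attn_tp_size)), local_shard)
    return full


def scatter_attn_output(full: torch.Tensor, attn_tp_size: int, rank: int) -> torch.Tensor:
    # After attention, reduce-scatter sums the partial results across ranks and
    # hands each rank back only its own token shard.
    shards = list(full.tensor_split(attn_tp_size))
    out = torch.empty_like(shards[rank])
    dist.reduce_scatter(out, shards)
    return out
```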
```diff
@@ -1169,7 +1310,10 @@ class DeepseekV2Model(nn.Module):
                 positions, hidden_states, forward_batch, residual
             )
         if not forward_batch.forward_mode.is_idle():
-
+            if residual is None:
+                hidden_states = self.norm(hidden_states)
+            else:
+                hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states


@@ -1183,7 +1327,28 @@ class DeepseekV2ForCausalLM(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+        if (
+            global_server_args_dict.get("disable_shared_experts_fusion", False)
+            or self.config.architectures[0] != "DeepseekV3ForCausalLM"
+            or self.config.n_routed_experts != 256
+            or self.config.routed_scaling_factor != 2.5
+        ):
+            self.n_share_experts_fusion = None
+            global_server_args_dict["n_share_experts_fusion"] = None
+            logger.info(
+                "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+            )
+        elif self.n_share_experts_fusion is None:
+            global_server_args_dict["n_share_experts_fusion"] = self.tp_size
+            self.n_share_experts_fusion = self.tp_size
+            logger.info(
+                f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
+            )
+
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
@@ -1196,6 +1361,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.dp_size = get_attention_dp_size()

+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     @torch.no_grad()
     def forward(
         self,
@@ -1211,12 +1379,127 @@ class DeepseekV2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )

+    def post_load_weights(self):
+
+        # Perform post-processing after loading weights
+
+        if not global_server_args_dict["disable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    if _is_cuda:
+                        w = awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                        ).T
+                    else:
+                        w = ops.awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                            0,
+                            0,
+                            0,
+                        ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+                # This may affect the accuracy of fp8 model.
+                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+                    torch.float8_e4m3fn,
+                    torch.float8_e4m3fnuz,
+                ):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        if _is_hip:
+                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                weight=w,
+                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                input_scale=None,
+                            )
+                        else:
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                        w, scale = block_quant_to_tensor_quant(
+                            weight, weight_scale, weight_block_size
+                        )
+                        self_attn.w_scale = scale
+                if w.dtype == torch.int8:
+                    if hasattr(self.quant_config, "weight_block_size"):
+                        # block-wise int8 need it
+                        weight_block_size = self.quant_config.weight_block_size
+                        if weight_block_size is not None:
+                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                            w = int8_block_dequant(
+                                weight, weight_scale, weight_block_size
+                            ).to(torch.bfloat16)
+                    else:
+                        # channel-wise int8 need it
+                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                            torch.bfloat16
+                        )
+                w_kc, w_vc = w.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if (
+                    hasattr(self_attn.kv_b_proj, "weight_scale")
+                    and self_attn.w_scale is None
+                ):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    if _is_hip:
+                        self_attn.w_scale *= 2.0
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+        if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
+            weights_list = list(weights)
+            weights_dict = dict(weights_list)
+            suffix_list = [
+                "down_proj.weight",
+                "down_proj.weight_scale_inv",
+                "gate_proj.weight",
+                "gate_proj.weight_scale_inv",
+                "up_proj.weight",
+                "up_proj.weight_scale_inv",
+            ]
+            names_to_remove = []
+            for moe_layer in tqdm(
+                range(
+                    self.config.first_k_dense_replace,
+                    self.config.num_hidden_layers,
+                    self.config.moe_layer_freq,
+                ),
+                desc=f"Cloning {self.n_share_experts_fusion} "
+                "replicas of the shared expert into MoE",
+            ):
+                for num_repeat in range(self.n_share_experts_fusion):
+                    for suffix in suffix_list:
+                        shared_expert_weight_name = (
+                            f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
+                        )
+                        weights_list.append(
+                            (
+                                f"model.layers.{moe_layer}."
+                                f"mlp.experts."
+                                f"{self.config.n_routed_experts + num_repeat}"
+                                f".{suffix}",
+                                weights_dict[shared_expert_weight_name].clone(),
+                            )
+                        )
+                        names_to_remove += [shared_expert_weight_name]
+            weights = [w for w in weights_list if w[0] not in names_to_remove]

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
```
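When fusion is active, the `load_weights` hunk above clones each MoE layer's shared-expert tensors into the extra routed-expert slots (`n_routed_experts + num_repeat`) and drops the original `shared_experts.*` names. A small sketch of just the renaming scheme; the layer index, expert count, and replica count are illustrative, and the real loop also covers the `*_scale_inv` suffixes:

```python
# Sketch of the shared-expert -> fused-expert renaming used in load_weights above.
def fused_expert_renames(moe_layer: int, n_routed_experts: int, n_share_experts_fusion: int):
    suffixes = ["gate_proj.weight", "up_proj.weight", "down_proj.weight"]
    renames = []
    for num_repeat in range(n_share_experts_fusion):
        for suffix in suffixes:
            src = f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
            dst = (
                f"model.layers.{moe_layer}.mlp.experts."
                f"{n_routed_experts + num_repeat}.{suffix}"
            )
            renames.append((src, dst))
    return renames

# Layer 3, 256 routed experts, 2 fused replicas: the shared expert's tensors are
# cloned into expert slots 256 and 257.
for src, dst in fused_expert_renames(3, 256, 2):
    print(src, "->", dst)
```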
```diff
@@ -1229,7 +1512,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts
+            + (
+                self.n_share_experts_fusion
+                if self.n_share_experts_fusion is not None
+                else 0
+            ),
         )

         params_dict = dict(self.named_parameters())
@@ -1293,79 +1581,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                     )
                     weight_loader(param, loaded_weight)

-
-        for layer_id in range(self.config.num_hidden_layers):
-            self_attn = self.model.layers[layer_id].self_attn
-            if hasattr(self_attn.kv_b_proj, "qweight"):
-                # AWQ compatible
-                if _is_cuda:
-                    w = awq_dequantize(
-                        self_attn.kv_b_proj.qweight,
-                        self_attn.kv_b_proj.scales,
-                        self_attn.kv_b_proj.qzeros,
-                    ).T
-                else:
-                    w = ops.awq_dequantize(
-                        self_attn.kv_b_proj.qweight,
-                        self_attn.kv_b_proj.scales,
-                        self_attn.kv_b_proj.qzeros,
-                        0,
-                        0,
-                        0,
-                    ).T
-            else:
-                w = self_attn.kv_b_proj.weight
-            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-            # This may affect the accuracy of fp8 model.
-            if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-                torch.float8_e4m3fn,
-                torch.float8_e4m3fnuz,
-            ):
-                weight_block_size = self.quant_config.weight_block_size
-                if weight_block_size is not None:
-                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    if _is_hip:
-                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                            weight=w,
-                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                            input_scale=None,
-                        )
-                    else:
-                        weight = w
-                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                    w, scale = block_quant_to_tensor_quant(
-                        weight, weight_scale, weight_block_size
-                    )
-                    self_attn.w_scale = scale
-            if w.dtype == torch.int8:
-                if hasattr(self.quant_config, "weight_block_size"):
-                    # block-wise int8 need it
-                    weight_block_size = self.quant_config.weight_block_size
-                    if weight_block_size is not None:
-                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                        weight = w
-                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                        w = int8_block_dequant(
-                            weight, weight_scale, weight_block_size
-                        ).to(torch.bfloat16)
-                else:
-                    # channel-wise int8 need it
-                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                        torch.bfloat16
-                    )
-            w_kc, w_vc = w.unflatten(
-                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-            if (
-                hasattr(self_attn.kv_b_proj, "weight_scale")
-                and self_attn.w_scale is None
-            ):
-                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                if _is_hip:
-                    self_attn.w_scale *= 2.0
+        self.post_load_weights()

     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
```