sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- sglang/bench_serving.py +49 -7
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/layers/attention/flashattention_backend.py +394 -76
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +49 -3
- sglang/srt/layers/quantization/__init__.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/rotary_embedding.py +0 -12
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +7 -26
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -128
- sglang/srt/managers/scheduler.py +4 -4
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +8 -6
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +59 -57
- sglang/srt/model_loader/loader.py +8 -0
- sglang/srt/models/clip.py +12 -7
- sglang/srt/models/deepseek_janus_pro.py +10 -15
- sglang/srt/models/deepseek_v2.py +212 -121
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_mm.py +14 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +18 -6
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +99 -14
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +93 -24
- sglang/srt/utils.py +104 -51
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +13 -26
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -16,12 +16,14 @@
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""

+import logging
 import os
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
 import torch.nn.functional as F
 from torch import nn
+from tqdm import tqdm
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
@@ -30,9 +32,6 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
-    decode_attention_fwd_grouped_rope,
-)
 from sglang.srt.layers.dp_attention import (
     dp_gather_partial,
     dp_scatter,
@@ -73,7 +72,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda, is_hip
+from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -83,8 +82,15 @@ if _is_cuda:
 else:
     from vllm import _custom_ops as ops

+if _is_hip:
+    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+        decode_attention_fwd_grouped_rope,
+    )
+
 expert_distribution_recorder = ExpertDistributionRecorder()

+logger = logging.getLogger(__name__)
+

 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -166,6 +172,12 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
+        self.n_share_experts_fusion = (
+            global_server_args_dict["n_share_experts_fusion"]
+            if global_server_args_dict["n_share_experts_fusion"] is not None
+            else 0
+        )
+
         self.routed_scaling_factor = config.routed_scaling_factor
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -186,9 +198,10 @@ class DeepseekV2MoE(nn.Module):
             if global_server_args_dict["enable_deepep_moe"]
             else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
         )
+
         self.experts = MoEImpl(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
+            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
+            top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             renormalize=config.norm_topk_prob,
@@ -198,9 +211,14 @@ class DeepseekV2MoE(nn.Module):
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
             prefix=add_prefix("experts", prefix),
+            **(
+                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                if global_server_args_dict["enable_deepep_moe"]
+                else {}
+            ),
         )

-        if config.n_shared_experts is not None:
+        if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
             if not global_server_args_dict["enable_deepep_moe"]:
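The two hunks above size the expert pool for shared-expert fusion: the shared expert is appended as extra routed-expert slots, and top_k grows by one so the router can always select a fused slot. A minimal sketch of that sizing arithmetic, using hypothetical stand-alone names rather than the real config objects:

from typing import Tuple


def fused_moe_shape(
    n_routed_experts: int, num_experts_per_tok: int, n_share_experts_fusion: int
) -> Tuple[int, int]:
    # Mirrors the sizing arithmetic in the constructor changes above.
    num_experts = n_routed_experts + n_share_experts_fusion
    top_k = num_experts_per_tok + min(n_share_experts_fusion, 1)
    return num_experts, top_k


# DeepSeek-V3/R1: 256 routed experts, top-8 routing.
assert fused_moe_shape(256, 8, 8) == (264, 9)   # e.g. 8 fused replicas (one per TP rank)
assert fused_moe_shape(256, 8, 1) == (257, 9)
assert fused_moe_shape(256, 8, 0) == (256, 8)   # fusion disabled: unchanged

The resulting expert counts 257 and 264 line up with the new fused MoE tuning configs (E=257 and E=264 for H200) added in this release's file list.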
@@ -225,6 +243,8 @@ class DeepseekV2MoE(nn.Module):
             )

         if global_server_args_dict["enable_deepep_moe"]:
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_tensor_model_parallel_world_size()
             self.num_experts = config.n_routed_experts
             self.top_k = config.num_experts_per_tok
             self.renormalize = config.norm_topk_prob
@@ -244,7 +264,9 @@ class DeepseekV2MoE(nn.Module):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
+                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
                 async_finish=True,  # TODO
+                return_recv_hook=True,
             )

     def forward(
@@ -256,8 +278,10 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_deepep(hidden_states, forward_mode)

     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if self.n_shared_experts is not None:
+        if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
             shared_output = self.shared_experts(hidden_states)
+        else:
+            shared_output = None
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         final_hidden_states = (
@@ -299,28 +323,39 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
         )
-        if self.
-
-
-
-
-
-
-
-
+        if self.ep_size > 1:
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            ) = self.deepep_dispatcher.dispatch(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                self.num_experts,
+                forward_mode=forward_mode,
             )
         final_hidden_states = (
             self.experts(
-                hidden_states=
+                hidden_states=hidden_states,
                 reorder_topk_ids=reorder_topk_ids,
                 seg_indptr=seg_indptr,
+                masked_m=masked_m,
+                expected_m=expected_m,
                 forward_mode=forward_mode,
             )
             * self.routed_scaling_factor
         )
-        if self.
+        if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
+                final_hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode,
             )
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
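In the rewritten forward_deepep path above, token dispatch and combine are delegated to the DeepEP dispatcher whenever expert parallelism spans more than one rank. A minimal control-flow sketch of that pattern, with a hypothetical Dispatcher protocol standing in for the real deepep_dispatcher object and a plain callable standing in for the expert kernel:

from typing import Optional, Protocol

import torch


class Dispatcher(Protocol):
    # Hypothetical stand-in for the DeepEP token dispatcher used above.
    def dispatch(self, hidden_states, topk_idx, topk_weights, num_experts, forward_mode=None): ...
    def combine(self, hidden_states, topk_idx, topk_weights, forward_mode=None): ...


def forward_deepep_sketch(
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    experts,                      # callable expert kernel (assumption of this sketch)
    dispatcher: Dispatcher,
    ep_size: int,
    num_experts: int,
    routed_scaling_factor: float,
    forward_mode=None,
    shared_output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Dispatch metadata; the single-rank path in the real code obtains these
    # elsewhere, which this sketch glosses over.
    reorder_topk_ids = seg_indptr = masked_m = expected_m = None

    # 1) Scatter tokens to the ranks that own their selected experts.
    if ep_size > 1:
        (
            hidden_states,
            topk_idx,
            topk_weights,
            reorder_topk_ids,
            seg_indptr,
            masked_m,
            expected_m,
        ) = dispatcher.dispatch(
            hidden_states, topk_idx, topk_weights, num_experts, forward_mode=forward_mode
        )

    # 2) Run the local grouped experts, then apply the routed scaling factor.
    out = (
        experts(
            hidden_states=hidden_states,
            reorder_topk_ids=reorder_topk_ids,
            seg_indptr=seg_indptr,
            masked_m=masked_m,
            expected_m=expected_m,
            forward_mode=forward_mode,
        )
        * routed_scaling_factor
    )

    # 3) Gather expert outputs back to the ranks that own the tokens.
    if ep_size > 1:
        out = dispatcher.combine(out, topk_idx, topk_weights, forward_mode)

    # 4) Add the (non-fused) shared-expert output, if any.
    return out if shared_output is None else out + shared_output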
@@ -651,7 +686,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_vc = None
         self.w_scale = None

-        self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
@@ -659,7 +693,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

     def no_absorb(self, forward_batch: ForwardBatch) -> bool:
-        if self.
+        if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             return (
                 not self.flashinfer_mla_disable_ragged
@@ -1100,6 +1134,10 @@ class DeepseekV2DecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)

+        assert not (
+            self.attn_tp_size != 1 and self.input_is_scattered
+        ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
+
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
@@ -1107,22 +1145,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             forward_batch=forward_batch,
         )

-        if self.attn_tp_size != 1 and self.input_is_scattered:
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
-            )
-            residual, local_residual = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                residual,
-            )
-            tp_all_gather(
-                list(residual.tensor_split(self.attn_tp_size)), local_residual
-            )
-
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
@@ -1221,6 +1243,8 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

         if self.is_last_layer and self.attn_tp_size != 1:
+            hidden_states += residual
+            residual = None
             hidden_states, local_hidden_states = (
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
@@ -1228,19 +1252,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )
-            residual, local_residual = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                residual,
-            )
-            tp_all_gather(
-                list(residual.tensor_split(self.attn_tp_size)), local_residual
-            )

         return hidden_states, residual


 class DeepseekV2Model(nn.Module):
-
     fall_back_to_pt_during_load = False

     def __init__(
@@ -1294,7 +1310,10 @@ class DeepseekV2Model(nn.Module):
                 positions, hidden_states, forward_batch, residual
             )
         if not forward_batch.forward_mode.is_idle():
-
+            if residual is None:
+                hidden_states = self.norm(hidden_states)
+            else:
+                hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states


@@ -1308,7 +1327,28 @@ class DeepseekV2ForCausalLM(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+        if (
+            global_server_args_dict.get("disable_shared_experts_fusion", False)
+            or self.config.architectures[0] != "DeepseekV3ForCausalLM"
+            or self.config.n_routed_experts != 256
+            or self.config.routed_scaling_factor != 2.5
+        ):
+            self.n_share_experts_fusion = None
+            global_server_args_dict["n_share_experts_fusion"] = None
+            logger.info(
+                "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+            )
+        elif self.n_share_experts_fusion is None:
+            global_server_args_dict["n_share_experts_fusion"] = self.tp_size
+            self.n_share_experts_fusion = self.tp_size
+            logger.info(
+                f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
+            )
+
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
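The constructor hunk above decides whether shared-expert fusion is used and how many replicas to fuse. A condensed sketch of that decision logic, where server_args stands in for global_server_args_dict and config for the HF PretrainedConfig (both assumptions of this sketch):

from typing import Optional


def resolve_n_share_experts_fusion(server_args: dict, config, tp_size: int) -> Optional[int]:
    # Illustrative only: mirrors the eligibility and default rules in the diff above.
    n = server_args.get("n_share_experts_fusion")
    eligible = (
        not server_args.get("disable_shared_experts_fusion", False)
        and config.architectures[0] == "DeepseekV3ForCausalLM"
        and config.n_routed_experts == 256
        and config.routed_scaling_factor == 2.5
    )
    if not eligible:
        return None       # fusion only applies to DeepSeek V3/R1 checkpoints
    if n is None:
        return tp_size    # default: one fused replica per tensor-parallel rank
    return n              # an explicit --n_share_experts_fusion value wins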
@@ -1321,6 +1361,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.dp_size = get_attention_dp_size()

+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     @torch.no_grad()
     def forward(
         self,
@@ -1336,12 +1379,127 @@ class DeepseekV2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )

+    def post_load_weights(self):
+
+        # Perform post-processing after loading weights
+
+        if not global_server_args_dict["disable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    if _is_cuda:
+                        w = awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                        ).T
+                    else:
+                        w = ops.awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                            0,
+                            0,
+                            0,
+                        ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+                # This may affect the accuracy of fp8 model.
+                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+                    torch.float8_e4m3fn,
+                    torch.float8_e4m3fnuz,
+                ):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        if _is_hip:
+                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                weight=w,
+                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                input_scale=None,
+                            )
+                        else:
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                        w, scale = block_quant_to_tensor_quant(
+                            weight, weight_scale, weight_block_size
+                        )
+                        self_attn.w_scale = scale
+                if w.dtype == torch.int8:
+                    if hasattr(self.quant_config, "weight_block_size"):
+                        # block-wise int8 need it
+                        weight_block_size = self.quant_config.weight_block_size
+                        if weight_block_size is not None:
+                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                            w = int8_block_dequant(
+                                weight, weight_scale, weight_block_size
+                            ).to(torch.bfloat16)
+                    else:
+                        # channel-wise int8 need it
+                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                            torch.bfloat16
+                        )
+                w_kc, w_vc = w.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if (
+                    hasattr(self_attn.kv_b_proj, "weight_scale")
+                    and self_attn.w_scale is None
+                ):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    if _is_hip:
+                        self_attn.w_scale *= 2.0
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+        if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
+            weights_list = list(weights)
+            weights_dict = dict(weights_list)
+            suffix_list = [
+                "down_proj.weight",
+                "down_proj.weight_scale_inv",
+                "gate_proj.weight",
+                "gate_proj.weight_scale_inv",
+                "up_proj.weight",
+                "up_proj.weight_scale_inv",
+            ]
+            names_to_remove = []
+            for moe_layer in tqdm(
+                range(
+                    self.config.first_k_dense_replace,
+                    self.config.num_hidden_layers,
+                    self.config.moe_layer_freq,
+                ),
+                desc=f"Cloning {self.n_share_experts_fusion} "
+                "replicas of the shared expert into MoE",
+            ):
+                for num_repeat in range(self.n_share_experts_fusion):
+                    for suffix in suffix_list:
+                        shared_expert_weight_name = (
+                            f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
+                        )
+                        weights_list.append(
+                            (
+                                f"model.layers.{moe_layer}."
+                                f"mlp.experts."
+                                f"{self.config.n_routed_experts + num_repeat}"
+                                f".{suffix}",
+                                weights_dict[shared_expert_weight_name].clone(),
+                            )
+                        )
+                        names_to_remove += [shared_expert_weight_name]
+            weights = [w for w in weights_list if w[0] not in names_to_remove]

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
@@ -1354,7 +1512,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts
+            num_experts=self.config.n_routed_experts
+            + (
+                self.n_share_experts_fusion
+                if self.n_share_experts_fusion is not None
+                else 0
+            ),
         )

         params_dict = dict(self.named_parameters())
@@ -1418,79 +1581,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)

-
-        for layer_id in range(self.config.num_hidden_layers):
-            self_attn = self.model.layers[layer_id].self_attn
-            if hasattr(self_attn.kv_b_proj, "qweight"):
-                # AWQ compatible
-                if _is_cuda:
-                    w = awq_dequantize(
-                        self_attn.kv_b_proj.qweight,
-                        self_attn.kv_b_proj.scales,
-                        self_attn.kv_b_proj.qzeros,
-                    ).T
-                else:
-                    w = ops.awq_dequantize(
-                        self_attn.kv_b_proj.qweight,
-                        self_attn.kv_b_proj.scales,
-                        self_attn.kv_b_proj.qzeros,
-                        0,
-                        0,
-                        0,
-                    ).T
-            else:
-                w = self_attn.kv_b_proj.weight
-            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-            # This may affect the accuracy of fp8 model.
-            if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-                torch.float8_e4m3fn,
-                torch.float8_e4m3fnuz,
-            ):
-                weight_block_size = self.quant_config.weight_block_size
-                if weight_block_size is not None:
-                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    if _is_hip:
-                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                            weight=w,
-                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                            input_scale=None,
-                        )
-                    else:
-                        weight = w
-                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                    w, scale = block_quant_to_tensor_quant(
-                        weight, weight_scale, weight_block_size
-                    )
-                    self_attn.w_scale = scale
-            if w.dtype == torch.int8:
-                if hasattr(self.quant_config, "weight_block_size"):
-                    # block-wise int8 need it
-                    weight_block_size = self.quant_config.weight_block_size
-                    if weight_block_size is not None:
-                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                        weight = w
-                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                        w = int8_block_dequant(
-                            weight, weight_scale, weight_block_size
-                        ).to(torch.bfloat16)
-                else:
-                    # channel-wise int8 need it
-                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                        torch.bfloat16
-                    )
-            w_kc, w_vc = w.unflatten(
-                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-            if (
-                hasattr(self_attn.kv_b_proj, "weight_scale")
-                and self_attn.w_scale is None
-            ):
-                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                if _is_hip:
-                    self_attn.w_scale *= 2.0
+        self.post_load_weights()

     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight