sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/bench_serving.py +49 -7
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/model_config.py +1 -0
  4. sglang/srt/constrained/base_grammar_backend.py +5 -1
  5. sglang/srt/custom_op.py +5 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  7. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  8. sglang/srt/entrypoints/engine.py +0 -5
  9. sglang/srt/layers/attention/flashattention_backend.py +394 -76
  10. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  11. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  12. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  13. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  14. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  15. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  20. sglang/srt/layers/moe/topk.py +49 -3
  21. sglang/srt/layers/quantization/__init__.py +4 -1
  22. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  23. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  24. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  25. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  26. sglang/srt/layers/quantization/utils.py +1 -1
  27. sglang/srt/layers/rotary_embedding.py +0 -12
  28. sglang/srt/managers/cache_controller.py +34 -11
  29. sglang/srt/managers/mm_utils.py +202 -156
  30. sglang/srt/managers/multimodal_processor.py +0 -2
  31. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  32. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  33. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  34. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  35. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  36. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  37. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  38. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  40. sglang/srt/managers/schedule_batch.py +185 -128
  41. sglang/srt/managers/scheduler.py +4 -4
  42. sglang/srt/managers/tokenizer_manager.py +1 -1
  43. sglang/srt/managers/utils.py +1 -6
  44. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  45. sglang/srt/mem_cache/memory_pool.py +72 -6
  46. sglang/srt/mem_cache/paged_allocator.py +39 -0
  47. sglang/srt/metrics/collector.py +23 -53
  48. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  49. sglang/srt/model_executor/forward_batch_info.py +10 -10
  50. sglang/srt/model_executor/model_runner.py +59 -57
  51. sglang/srt/model_loader/loader.py +8 -0
  52. sglang/srt/models/clip.py +12 -7
  53. sglang/srt/models/deepseek_janus_pro.py +10 -15
  54. sglang/srt/models/deepseek_v2.py +212 -121
  55. sglang/srt/models/deepseek_vl2.py +105 -104
  56. sglang/srt/models/gemma3_mm.py +14 -80
  57. sglang/srt/models/llama.py +4 -1
  58. sglang/srt/models/llava.py +31 -19
  59. sglang/srt/models/llavavid.py +16 -7
  60. sglang/srt/models/minicpmo.py +63 -147
  61. sglang/srt/models/minicpmv.py +17 -27
  62. sglang/srt/models/mllama.py +29 -14
  63. sglang/srt/models/qwen2.py +9 -6
  64. sglang/srt/models/qwen2_5_vl.py +21 -31
  65. sglang/srt/models/qwen2_vl.py +20 -21
  66. sglang/srt/openai_api/adapter.py +18 -6
  67. sglang/srt/platforms/interface.py +371 -0
  68. sglang/srt/server_args.py +99 -14
  69. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  70. sglang/srt/speculative/eagle_utils.py +140 -28
  71. sglang/srt/speculative/eagle_worker.py +93 -24
  72. sglang/srt/utils.py +104 -51
  73. sglang/test/test_custom_ops.py +55 -0
  74. sglang/test/test_utils.py +13 -26
  75. sglang/utils.py +2 -2
  76. sglang/version.py +1 -1
  77. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
  78. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
  79. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  80. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  81. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
@@ -16,12 +16,14 @@
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""

+import logging
 import os
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
 import torch.nn.functional as F
 from torch import nn
+from tqdm import tqdm
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
@@ -30,9 +32,6 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
-    decode_attention_fwd_grouped_rope,
-)
 from sglang.srt.layers.dp_attention import (
     dp_gather_partial,
     dp_scatter,
@@ -73,7 +72,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda, is_hip
+from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -83,8 +82,15 @@ if _is_cuda:
 else:
     from vllm import _custom_ops as ops

+if _is_hip:
+    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+        decode_attention_fwd_grouped_rope,
+    )
+
 expert_distribution_recorder = ExpertDistributionRecorder()

+logger = logging.getLogger(__name__)
+

 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -166,6 +172,12 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
+        self.n_share_experts_fusion = (
+            global_server_args_dict["n_share_experts_fusion"]
+            if global_server_args_dict["n_share_experts_fusion"] is not None
+            else 0
+        )
+
         self.routed_scaling_factor = config.routed_scaling_factor
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -186,9 +198,10 @@ class DeepseekV2MoE(nn.Module):
             if global_server_args_dict["enable_deepep_moe"]
             else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
         )
+
         self.experts = MoEImpl(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
+            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
+            top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             renormalize=config.norm_topk_prob,
@@ -198,9 +211,14 @@ class DeepseekV2MoE(nn.Module):
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
             prefix=add_prefix("experts", prefix),
+            **(
+                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                if global_server_args_dict["enable_deepep_moe"]
+                else {}
+            ),
         )

-        if config.n_shared_experts is not None:
+        if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
             if not global_server_args_dict["enable_deepep_moe"]:
@@ -225,6 +243,8 @@ class DeepseekV2MoE(nn.Module):
             )

         if global_server_args_dict["enable_deepep_moe"]:
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_tensor_model_parallel_world_size()
             self.num_experts = config.n_routed_experts
             self.top_k = config.num_experts_per_tok
             self.renormalize = config.norm_topk_prob
@@ -244,7 +264,9 @@ class DeepseekV2MoE(nn.Module):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
+                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
                 async_finish=True,  # TODO
+                return_recv_hook=True,
             )

     def forward(
@@ -256,8 +278,10 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_deepep(hidden_states, forward_mode)

     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if self.n_shared_experts is not None:
+        if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
             shared_output = self.shared_experts(hidden_states)
+        else:
+            shared_output = None
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         final_hidden_states = (
@@ -299,28 +323,39 @@ class DeepseekV2MoE(nn.Module):
                 num_expert_group=self.num_expert_group,
                 correction_bias=self.correction_bias,
             )
-        if self.tp_size > 1:
-            recv_hidden_states, reorder_topk_ids, seg_indptr = (
-                self.deepep_dispatcher.dispatch(
-                    hidden_states,
-                    topk_idx,
-                    topk_weights,
-                    self.num_experts,
-                    forward_mode,
-                )
+        if self.ep_size > 1:
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            ) = self.deepep_dispatcher.dispatch(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                self.num_experts,
+                forward_mode=forward_mode,
             )
         final_hidden_states = (
             self.experts(
-                hidden_states=recv_hidden_states,
+                hidden_states=hidden_states,
                 reorder_topk_ids=reorder_topk_ids,
                 seg_indptr=seg_indptr,
+                masked_m=masked_m,
+                expected_m=expected_m,
                 forward_mode=forward_mode,
             )
             * self.routed_scaling_factor
         )
-        if self.tp_size > 1:
+        if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states, forward_mode
+                final_hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode,
             )
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
@@ -651,7 +686,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_vc = None
         self.w_scale = None

-        self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
@@ -659,7 +693,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

     def no_absorb(self, forward_batch: ForwardBatch) -> bool:
-        if self.enable_flashinfer_mla:
+        if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             return (
                 not self.flashinfer_mla_disable_ragged
@@ -1100,6 +1134,10 @@ class DeepseekV2DecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)

+        assert not (
+            self.attn_tp_size != 1 and self.input_is_scattered
+        ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
+
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
@@ -1107,22 +1145,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             forward_batch=forward_batch,
         )

-        if self.attn_tp_size != 1 and self.input_is_scattered:
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
-            )
-            residual, local_residual = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                residual,
-            )
-            tp_all_gather(
-                list(residual.tensor_split(self.attn_tp_size)), local_residual
-            )
-
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
@@ -1221,6 +1243,8 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

         if self.is_last_layer and self.attn_tp_size != 1:
+            hidden_states += residual
+            residual = None
             hidden_states, local_hidden_states = (
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
@@ -1228,19 +1252,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )
-            residual, local_residual = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                residual,
-            )
-            tp_all_gather(
-                list(residual.tensor_split(self.attn_tp_size)), local_residual
-            )

         return hidden_states, residual


 class DeepseekV2Model(nn.Module):
-
     fall_back_to_pt_during_load = False

     def __init__(
@@ -1294,7 +1310,10 @@ class DeepseekV2Model(nn.Module):
                 positions, hidden_states, forward_batch, residual
             )
         if not forward_batch.forward_mode.is_idle():
-            hidden_states, _ = self.norm(hidden_states, residual)
+            if residual is None:
+                hidden_states = self.norm(hidden_states)
+            else:
+                hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

@@ -1308,7 +1327,28 @@ class DeepseekV2ForCausalLM(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+        if (
+            global_server_args_dict.get("disable_shared_experts_fusion", False)
+            or self.config.architectures[0] != "DeepseekV3ForCausalLM"
+            or self.config.n_routed_experts != 256
+            or self.config.routed_scaling_factor != 2.5
+        ):
+            self.n_share_experts_fusion = None
+            global_server_args_dict["n_share_experts_fusion"] = None
+            logger.info(
+                "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+            )
+        elif self.n_share_experts_fusion is None:
+            global_server_args_dict["n_share_experts_fusion"] = self.tp_size
+            self.n_share_experts_fusion = self.tp_size
+            logger.info(
+                f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
+            )
+
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
@@ -1321,6 +1361,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.dp_size = get_attention_dp_size()

+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     @torch.no_grad()
     def forward(
         self,
@@ -1336,12 +1379,127 @@ class DeepseekV2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )

+    def post_load_weights(self):
+
+        # Perform post-processing after loading weights
+
+        if not global_server_args_dict["disable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    if _is_cuda:
+                        w = awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                        ).T
+                    else:
+                        w = ops.awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                            0,
+                            0,
+                            0,
+                        ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+                # This may affect the accuracy of fp8 model.
+                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+                    torch.float8_e4m3fn,
+                    torch.float8_e4m3fnuz,
+                ):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        if _is_hip:
+                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                weight=w,
+                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                input_scale=None,
+                            )
+                        else:
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                        w, scale = block_quant_to_tensor_quant(
+                            weight, weight_scale, weight_block_size
+                        )
+                        self_attn.w_scale = scale
+                if w.dtype == torch.int8:
+                    if hasattr(self.quant_config, "weight_block_size"):
+                        # block-wise int8 need it
+                        weight_block_size = self.quant_config.weight_block_size
+                        if weight_block_size is not None:
+                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                            w = int8_block_dequant(
+                                weight, weight_scale, weight_block_size
+                            ).to(torch.bfloat16)
+                    else:
+                        # channel-wise int8 need it
+                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                            torch.bfloat16
+                        )
+                w_kc, w_vc = w.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if (
+                    hasattr(self_attn.kv_b_proj, "weight_scale")
+                    and self_attn.w_scale is None
+                ):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    if _is_hip:
+                        self_attn.w_scale *= 2.0
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+        if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
+            weights_list = list(weights)
+            weights_dict = dict(weights_list)
+            suffix_list = [
+                "down_proj.weight",
+                "down_proj.weight_scale_inv",
+                "gate_proj.weight",
+                "gate_proj.weight_scale_inv",
+                "up_proj.weight",
+                "up_proj.weight_scale_inv",
+            ]
+            names_to_remove = []
+            for moe_layer in tqdm(
+                range(
+                    self.config.first_k_dense_replace,
+                    self.config.num_hidden_layers,
+                    self.config.moe_layer_freq,
+                ),
+                desc=f"Cloning {self.n_share_experts_fusion} "
+                "replicas of the shared expert into MoE",
+            ):
+                for num_repeat in range(self.n_share_experts_fusion):
+                    for suffix in suffix_list:
+                        shared_expert_weight_name = (
+                            f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
+                        )
+                        weights_list.append(
+                            (
+                                f"model.layers.{moe_layer}."
+                                f"mlp.experts."
+                                f"{self.config.n_routed_experts + num_repeat}"
+                                f".{suffix}",
+                                weights_dict[shared_expert_weight_name].clone(),
+                            )
+                        )
+                        names_to_remove += [shared_expert_weight_name]
+            weights = [w for w in weights_list if w[0] not in names_to_remove]

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
@@ -1354,7 +1512,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts
+            + (
+                self.n_share_experts_fusion
+                if self.n_share_experts_fusion is not None
+                else 0
+            ),
         )

         params_dict = dict(self.named_parameters())
@@ -1418,79 +1581,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)

-        if not global_server_args_dict["disable_mla"]:
-            for layer_id in range(self.config.num_hidden_layers):
-                self_attn = self.model.layers[layer_id].self_attn
-                if hasattr(self_attn.kv_b_proj, "qweight"):
-                    # AWQ compatible
-                    if _is_cuda:
-                        w = awq_dequantize(
-                            self_attn.kv_b_proj.qweight,
-                            self_attn.kv_b_proj.scales,
-                            self_attn.kv_b_proj.qzeros,
-                        ).T
-                    else:
-                        w = ops.awq_dequantize(
-                            self_attn.kv_b_proj.qweight,
-                            self_attn.kv_b_proj.scales,
-                            self_attn.kv_b_proj.qzeros,
-                            0,
-                            0,
-                            0,
-                        ).T
-                else:
-                    w = self_attn.kv_b_proj.weight
-                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-                # This may affect the accuracy of fp8 model.
-                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-                    torch.float8_e4m3fn,
-                    torch.float8_e4m3fnuz,
-                ):
-                    weight_block_size = self.quant_config.weight_block_size
-                    if weight_block_size is not None:
-                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                        if _is_hip:
-                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                                weight=w,
-                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                                input_scale=None,
-                            )
-                        else:
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                        w, scale = block_quant_to_tensor_quant(
-                            weight, weight_scale, weight_block_size
-                        )
-                        self_attn.w_scale = scale
-                if w.dtype == torch.int8:
-                    if hasattr(self.quant_config, "weight_block_size"):
-                        # block-wise int8 need it
-                        weight_block_size = self.quant_config.weight_block_size
-                        if weight_block_size is not None:
-                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                            w = int8_block_dequant(
-                                weight, weight_scale, weight_block_size
-                            ).to(torch.bfloat16)
-                    else:
-                        # channel-wise int8 need it
-                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                            torch.bfloat16
-                        )
-                w_kc, w_vc = w.unflatten(
-                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-                if (
-                    hasattr(self_attn.kv_b_proj, "weight_scale")
-                    and self_attn.w_scale is None
-                ):
-                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                    if _is_hip:
-                        self_attn.w_scale *= 2.0
+        self.post_load_weights()

     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
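
Note on the shared-experts-fusion sizing visible in the hunks at -186,9 +198,10 and -1308,7 +1327,28 above: the following is an illustrative sketch, not code from the wheel, assuming the DeepSeek V3/R1 configuration values the diff checks for (n_routed_experts=256, num_experts_per_tok=8) and a hypothetical tp_size of 8. The resulting expert count of 264 matches the new E=264 fused_moe_triton tuning config added in this release (file 18 in the list above).

# Illustrative sketch only (not shipped in the wheel).
n_routed_experts = 256      # checked by DeepseekV2ForCausalLM.__init__ in this diff
num_experts_per_tok = 8     # DeepSeek V3/R1 router top-k (assumed here)
tp_size = 8                 # hypothetical tensor-parallel size

# n_share_experts_fusion defaults to tp_size when fusion is not disabled
# (hunk at -1308,7 +1327,28 above).
n_share_experts_fusion = tp_size

# DeepseekV2MoE then enlarges the expert pool and adds one routing slot for
# the replicated shared expert (hunk at -186,9 +198,10 above).
num_experts = n_routed_experts + n_share_experts_fusion       # 264
top_k = num_experts_per_tok + min(n_share_experts_fusion, 1)  # 9
print(num_experts, top_k)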