sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. sglang/bench_serving.py +72 -10
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/deepseekvl2.py +10 -1
  4. sglang/srt/configs/model_config.py +6 -16
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/custom_op.py +5 -0
  7. sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
  8. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  9. sglang/srt/distributed/parallel_state.py +32 -5
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/entrypoints/http_server.py +7 -1
  12. sglang/srt/entrypoints/verl_engine.py +2 -0
  13. sglang/srt/function_call_parser.py +0 -1
  14. sglang/srt/layers/attention/flashattention_backend.py +582 -125
  15. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  17. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  18. sglang/srt/layers/dp_attention.py +12 -1
  19. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  20. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  21. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  26. sglang/srt/layers/moe/topk.py +79 -6
  27. sglang/srt/layers/quantization/__init__.py +137 -165
  28. sglang/srt/layers/quantization/awq.py +200 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  30. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  31. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  32. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  33. sglang/srt/layers/quantization/gptq.py +30 -40
  34. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  35. sglang/srt/layers/quantization/utils.py +1 -1
  36. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  37. sglang/srt/lora/backend/base_backend.py +4 -4
  38. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  39. sglang/srt/lora/backend/triton_backend.py +5 -8
  40. sglang/srt/lora/layers.py +19 -33
  41. sglang/srt/lora/lora_manager.py +20 -7
  42. sglang/srt/lora/mem_pool.py +12 -6
  43. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  44. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  45. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  46. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  47. sglang/srt/lora/utils.py +6 -0
  48. sglang/srt/managers/cache_controller.py +34 -11
  49. sglang/srt/managers/io_struct.py +4 -2
  50. sglang/srt/managers/mm_utils.py +202 -156
  51. sglang/srt/managers/multimodal_processor.py +0 -2
  52. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  53. sglang/srt/managers/multimodal_processors/clip.py +44 -0
  54. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  55. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  56. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  57. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  58. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  59. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  61. sglang/srt/managers/schedule_batch.py +185 -127
  62. sglang/srt/managers/scheduler.py +29 -23
  63. sglang/srt/managers/tokenizer_manager.py +1 -2
  64. sglang/srt/managers/tp_worker.py +3 -0
  65. sglang/srt/managers/utils.py +1 -6
  66. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  67. sglang/srt/mem_cache/memory_pool.py +72 -6
  68. sglang/srt/mem_cache/paged_allocator.py +39 -0
  69. sglang/srt/metrics/collector.py +23 -53
  70. sglang/srt/model_executor/cuda_graph_runner.py +16 -13
  71. sglang/srt/model_executor/forward_batch_info.py +10 -10
  72. sglang/srt/model_executor/model_runner.py +64 -59
  73. sglang/srt/model_loader/loader.py +19 -1
  74. sglang/srt/model_loader/weight_utils.py +6 -3
  75. sglang/srt/models/clip.py +568 -0
  76. sglang/srt/models/deepseek_janus_pro.py +12 -17
  77. sglang/srt/models/deepseek_v2.py +339 -123
  78. sglang/srt/models/deepseek_vl2.py +105 -104
  79. sglang/srt/models/gemma3_causal.py +12 -2
  80. sglang/srt/models/gemma3_mm.py +20 -80
  81. sglang/srt/models/llama.py +4 -1
  82. sglang/srt/models/llava.py +31 -19
  83. sglang/srt/models/llavavid.py +16 -7
  84. sglang/srt/models/minicpmo.py +63 -147
  85. sglang/srt/models/minicpmv.py +17 -27
  86. sglang/srt/models/mllama.py +29 -14
  87. sglang/srt/models/qwen2.py +9 -6
  88. sglang/srt/models/qwen2_5_vl.py +21 -31
  89. sglang/srt/models/qwen2_vl.py +20 -21
  90. sglang/srt/openai_api/adapter.py +106 -93
  91. sglang/srt/openai_api/protocol.py +10 -5
  92. sglang/srt/patch_torch.py +71 -0
  93. sglang/srt/platforms/interface.py +371 -0
  94. sglang/srt/server_args.py +120 -25
  95. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  96. sglang/srt/speculative/eagle_utils.py +140 -28
  97. sglang/srt/speculative/eagle_worker.py +94 -25
  98. sglang/srt/utils.py +137 -51
  99. sglang/test/runners.py +27 -2
  100. sglang/test/test_custom_ops.py +55 -0
  101. sglang/test/test_utils.py +14 -27
  102. sglang/utils.py +2 -2
  103. sglang/version.py +1 -1
  104. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
  105. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
  106. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  107. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  108. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
@@ -16,12 +16,14 @@
  # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
  """Inference-only DeepseekV2 model."""

+ import logging
  import os
  from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
  import torch.nn.functional as F
  from torch import nn
+ from tqdm import tqdm
  from transformers import PretrainedConfig

  from sglang.srt.distributed import (
@@ -30,15 +32,14 @@ from sglang.srt.distributed import (
  tensor_model_parallel_all_reduce,
  )
  from sglang.srt.layers.activation import SiluAndMul
- from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
- decode_attention_fwd_grouped_rope,
- )
  from sglang.srt.layers.dp_attention import (
  dp_gather_partial,
  dp_scatter,
  get_attention_dp_size,
  get_attention_tp_rank,
  get_attention_tp_size,
+ tp_all_gather,
+ tp_reduce_scatter,
  )
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -71,7 +72,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
+ from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip

  _is_hip = is_hip()
  _is_cuda = is_cuda()
@@ -81,8 +82,15 @@ if _is_cuda:
  else:
  from vllm import _custom_ops as ops

+ if _is_hip:
+ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+ decode_attention_fwd_grouped_rope,
+ )
+
  expert_distribution_recorder = ExpertDistributionRecorder()

+ logger = logging.getLogger(__name__)
+

  class DeepseekV2MLP(nn.Module):
  def __init__(
@@ -164,6 +172,12 @@ class DeepseekV2MoE(nn.Module):
  self.tp_size = get_tensor_model_parallel_world_size()
  self.routed_scaling_factor = config.routed_scaling_factor
  self.n_shared_experts = config.n_shared_experts
+ self.n_share_experts_fusion = (
+ global_server_args_dict["n_share_experts_fusion"]
+ if global_server_args_dict["n_share_experts_fusion"] is not None
+ else 0
+ )
+
  self.routed_scaling_factor = config.routed_scaling_factor
  if self.tp_size > config.n_routed_experts:
  raise ValueError(
@@ -184,9 +198,10 @@ class DeepseekV2MoE(nn.Module):
  if global_server_args_dict["enable_deepep_moe"]
  else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
  )
+
  self.experts = MoEImpl(
- num_experts=config.n_routed_experts,
- top_k=config.num_experts_per_tok,
+ num_experts=config.n_routed_experts + self.n_share_experts_fusion,
+ top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
  hidden_size=config.hidden_size,
  intermediate_size=config.moe_intermediate_size,
  renormalize=config.norm_topk_prob,
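
The hunk above folds n_share_experts_fusion copies of the shared expert into the routed-expert pool. A minimal standalone sketch of the resulting sizing (fused_moe_shape is an illustrative helper, not sglang API; the numbers mirror a DeepSeek-V3-style config):

    def fused_moe_shape(n_routed_experts: int, top_k: int, n_share_experts_fusion: int):
        # Extra expert slots hold the cloned shared experts; top_k grows by one
        # whenever at least one replica is fused into the routed pool.
        num_experts = n_routed_experts + n_share_experts_fusion
        effective_top_k = top_k + min(n_share_experts_fusion, 1)
        return num_experts, effective_top_k

    assert fused_moe_shape(256, 8, 8) == (264, 9)   # fusion enabled (e.g. tp_size=8)
    assert fused_moe_shape(256, 8, 0) == (256, 8)   # fusion disabled
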
@@ -196,9 +211,14 @@ class DeepseekV2MoE(nn.Module):
  topk_group=config.topk_group,
  correction_bias=self.gate.e_score_correction_bias,
  prefix=add_prefix("experts", prefix),
+ **(
+ dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+ if global_server_args_dict["enable_deepep_moe"]
+ else {}
+ ),
  )

- if config.n_shared_experts is not None:
+ if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
  intermediate_size = config.moe_intermediate_size * config.n_shared_experts
  # disable tp for shared experts when enable deepep moe
  if not global_server_args_dict["enable_deepep_moe"]:
@@ -223,6 +243,8 @@ class DeepseekV2MoE(nn.Module):
  )

  if global_server_args_dict["enable_deepep_moe"]:
+ # TODO: we will support tp < ep in the future
+ self.ep_size = get_tensor_model_parallel_world_size()
  self.num_experts = config.n_routed_experts
  self.top_k = config.num_experts_per_tok
  self.renormalize = config.norm_topk_prob
@@ -242,7 +264,9 @@ class DeepseekV2MoE(nn.Module):
  num_local_experts=config.n_routed_experts // self.tp_size,
  hidden_size=config.hidden_size,
  params_dtype=config.torch_dtype,
+ deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
  async_finish=True, # TODO
+ return_recv_hook=True,
  )

  def forward(
@@ -254,8 +278,10 @@ class DeepseekV2MoE(nn.Module):
  return self.forward_deepep(hidden_states, forward_mode)

  def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
- if self.n_shared_experts is not None:
+ if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
  shared_output = self.shared_experts(hidden_states)
+ else:
+ shared_output = None
  # router_logits: (num_tokens, n_experts)
  router_logits = self.gate(hidden_states)
  final_hidden_states = (
@@ -278,7 +304,11 @@ class DeepseekV2MoE(nn.Module):
  topk_weights = torch.empty(
  (0, self.top_k), dtype=torch.float32, device=hidden_states.device
  )
- if forward_mode is not None and not forward_mode.is_idle():
+ if (
+ forward_mode is not None
+ and not forward_mode.is_idle()
+ and hidden_states.shape[0] > 0
+ ):
  # router_logits: (num_tokens, n_experts)
  router_logits = self.gate(hidden_states)
  if self.n_shared_experts is not None:
@@ -293,28 +323,39 @@ class DeepseekV2MoE(nn.Module):
  num_expert_group=self.num_expert_group,
  correction_bias=self.correction_bias,
  )
- if self.tp_size > 1:
- recv_hidden_states, reorder_topk_ids, seg_indptr = (
- self.deepep_dispatcher.dispatch(
- hidden_states,
- topk_idx,
- topk_weights,
- self.num_experts,
- forward_mode,
- )
+ if self.ep_size > 1:
+ (
+ hidden_states,
+ topk_idx,
+ topk_weights,
+ reorder_topk_ids,
+ seg_indptr,
+ masked_m,
+ expected_m,
+ ) = self.deepep_dispatcher.dispatch(
+ hidden_states,
+ topk_idx,
+ topk_weights,
+ self.num_experts,
+ forward_mode=forward_mode,
  )
  final_hidden_states = (
  self.experts(
- hidden_states=recv_hidden_states,
+ hidden_states=hidden_states,
  reorder_topk_ids=reorder_topk_ids,
  seg_indptr=seg_indptr,
+ masked_m=masked_m,
+ expected_m=expected_m,
  forward_mode=forward_mode,
  )
  * self.routed_scaling_factor
  )
- if self.tp_size > 1:
+ if self.ep_size > 1:
  final_hidden_states = self.deepep_dispatcher.combine(
- final_hidden_states, forward_mode
+ final_hidden_states,
+ topk_idx,
+ topk_weights,
+ forward_mode,
  )
  if shared_output is not None:
  final_hidden_states = final_hidden_states + shared_output
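
The dispatcher interface is widened here: dispatch() now also returns the (possibly permuted) hidden states, top-k metadata, masked_m and expected_m, and combine() takes topk_idx/topk_weights. A single-process sketch of the control flow only, with a no-op stand-in for sglang's DeepEPDispatcher (no real all-to-all communication happens here):

    import torch

    class DummyDispatcher:
        """Shape-preserving stand-in; the real dispatcher permutes tokens across EP ranks."""

        def dispatch(self, hidden_states, topk_idx, topk_weights, num_experts, forward_mode=None):
            reorder_topk_ids = seg_indptr = masked_m = expected_m = None
            return hidden_states, topk_idx, topk_weights, reorder_topk_ids, seg_indptr, masked_m, expected_m

        def combine(self, hidden_states, topk_idx, topk_weights, forward_mode=None):
            return hidden_states

    ep_size, num_experts, top_k = 2, 8, 2
    dispatcher = DummyDispatcher()
    hidden_states = torch.randn(4, 16)
    topk_idx = torch.randint(0, num_experts, (4, top_k))
    topk_weights = torch.full((4, top_k), 1.0 / top_k)

    if ep_size > 1:
        (hidden_states, topk_idx, topk_weights,
         reorder_topk_ids, seg_indptr, masked_m, expected_m) = dispatcher.dispatch(
            hidden_states, topk_idx, topk_weights, num_experts, forward_mode=None
        )
    # ... grouped expert GEMMs would run here on the dispatched tokens ...
    if ep_size > 1:
        hidden_states = dispatcher.combine(hidden_states, topk_idx, topk_weights, None)
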
@@ -645,14 +686,14 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.w_vc = None
  self.w_scale = None

- self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
  self.flashinfer_mla_disable_ragged = global_server_args_dict[
  "flashinfer_mla_disable_ragged"
  ]
+ self.attention_backend = global_server_args_dict["attention_backend"]
  self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

  def no_absorb(self, forward_batch: ForwardBatch) -> bool:
- if self.enable_flashinfer_mla:
+ if self.attention_backend == "flashinfer":
  # Flashinfer MLA: Do not absorb when enabling ragged prefill
  return (
  not self.flashinfer_mla_disable_ragged
@@ -661,6 +702,9 @@ class DeepseekV2AttentionMLA(nn.Module):
  and not forward_batch.forward_mode.is_draft_extend()
  and sum(forward_batch.extend_prefix_lens_cpu) == 0
  )
+ elif self.attention_backend == "fa3":
+ # Flash Attention: Keep absorbing for all extend/decode
+ return False
  else:
  # Triton: Use normal computation for prefill and use weight absorption for extend/decode
  return (
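
With enable_flashinfer_mla gone, the absorb decision keys off the configured attention backend. A simplified sketch of that dispatch (the real no_absorb also inspects the ForwardBatch; the boolean argument here stands in for that check):

    def no_absorb(attention_backend: str, ragged_prefill_usable: bool) -> bool:
        if attention_backend == "flashinfer":
            # FlashInfer MLA: skip weight absorption when ragged prefill can be used.
            return ragged_prefill_usable
        elif attention_backend == "fa3":
            # FlashAttention 3 backend: always keep weight absorption.
            return False
        else:
            # Triton and other backends keep the prefill-dependent heuristic.
            return ragged_prefill_usable

    assert no_absorb("fa3", True) is False
    assert no_absorb("flashinfer", True) is True
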
@@ -969,6 +1013,14 @@ class DeepseekV2DecoderLayer(nn.Module):
  is_nextn: bool = False,
  prefix: str = "",
  ) -> None:
+
+ def is_sparse_layer(l: int):
+ return (
+ config.n_routed_experts is not None
+ and l >= config.first_k_dense_replace
+ and l % config.moe_layer_freq == 0
+ )
+
  super().__init__()
  self.hidden_size = config.hidden_size
  rope_theta = getattr(config, "rope_theta", 10000)
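
is_sparse_layer() centralizes the old inline condition for which decoder layers get a MoE block. Applied standalone to a hypothetical 61-layer config with first_k_dense_replace=3 and moe_layer_freq=1:

    def is_sparse_layer(l, n_routed_experts, first_k_dense_replace, moe_layer_freq):
        # A layer is "sparse" (MoE) once the dense warm-up layers are past and the
        # layer index lands on the MoE frequency grid.
        return (
            n_routed_experts is not None
            and l >= first_k_dense_replace
            and l % moe_layer_freq == 0
        )

    sparse = [l for l in range(61) if is_sparse_layer(l, 256, 3, 1)]
    assert sparse[0] == 3 and len(sparse) == 58  # layers 0-2 stay dense
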
@@ -977,6 +1029,8 @@ class DeepseekV2DecoderLayer(nn.Module):
  self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
  self.layer_id = layer_id
  self.dp_size = get_attention_dp_size()
+ self.attn_tp_size = get_attention_tp_size()
+ self.attn_tp_rank = get_attention_tp_rank()

  if not global_server_args_dict["disable_mla"]:
  self.self_attn = DeepseekV2AttentionMLA(
@@ -1019,16 +1073,13 @@ class DeepseekV2DecoderLayer(nn.Module):
  prefix=add_prefix("self_attn", prefix),
  )

- if is_nextn or (
- config.n_routed_experts is not None
- and layer_id >= config.first_k_dense_replace
- and layer_id % config.moe_layer_freq == 0
- ):
+ if is_nextn or is_sparse_layer(layer_id):
  self.mlp = DeepseekV2MoE(
  config=config,
  quant_config=quant_config,
  prefix=add_prefix("mlp", prefix),
  )
+ self.is_sparse = True
  else:
  self.mlp = DeepseekV2MLP(
  hidden_size=config.hidden_size,
@@ -1037,6 +1088,14 @@ class DeepseekV2DecoderLayer(nn.Module):
  quant_config=quant_config,
  prefix=add_prefix("mlp", prefix),
  )
+ self.is_sparse = False
+
+ self.input_is_scattered = (
+ is_sparse_layer(layer_id - 1)
+ and global_server_args_dict["enable_deepep_moe"]
+ )
+ self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
+
  self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  self.post_attention_layernorm = RMSNorm(
  config.hidden_size, eps=config.rms_norm_eps
@@ -1049,6 +1108,23 @@ class DeepseekV2DecoderLayer(nn.Module):
  forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  ) -> torch.Tensor:
+ if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
+ return self.forward_deepep(
+ positions, hidden_states, forward_batch, residual
+ )
+ else:
+ return self.forward_normal(
+ positions, hidden_states, forward_batch, residual
+ )
+
+ def forward_normal(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ forward_batch: ForwardBatch,
+ residual: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+
  if hidden_states.shape[0] == 0:
  residual = hidden_states
  else:
@@ -1058,6 +1134,10 @@ class DeepseekV2DecoderLayer(nn.Module):
  else:
  hidden_states, residual = self.input_layernorm(hidden_states, residual)

+ assert not (
+ self.attn_tp_size != 1 and self.input_is_scattered
+ ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
+
  # Self Attention
  hidden_states = self.self_attn(
  positions=positions,
@@ -1069,25 +1149,15 @@ class DeepseekV2DecoderLayer(nn.Module):
  if get_tensor_model_parallel_world_size() > 1:
  # all gather and all reduce
  if self.dp_size != 1:
- if global_server_args_dict["enable_deepep_moe"] and isinstance(
- self.mlp, DeepseekV2MoE
- ):
- if hidden_states.shape[0] != 0:
- hidden_states, residual = self.post_attention_layernorm(
- hidden_states, residual
- )
- hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
- return hidden_states, residual
- else:
- if get_attention_tp_rank() == 0:
- hidden_states += residual
- hidden_states, local_hidden_states = (
- forward_batch.gathered_buffer,
- hidden_states,
- )
- dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
- dp_scatter(residual, hidden_states, forward_batch)
- hidden_states = self.post_attention_layernorm(hidden_states)
+ if self.attn_tp_rank == 0:
+ hidden_states += residual
+ hidden_states, local_hidden_states = (
+ forward_batch.gathered_buffer,
+ hidden_states,
+ )
+ dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+ dp_scatter(residual, hidden_states, forward_batch)
+ hidden_states = self.post_attention_layernorm(hidden_states)
  else:
  hidden_states = tensor_model_parallel_all_reduce(hidden_states)
  hidden_states, residual = self.post_attention_layernorm(
@@ -1101,6 +1171,7 @@ class DeepseekV2DecoderLayer(nn.Module):
  # Fully Connected
  hidden_states = self.mlp(hidden_states)

+ # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
  # Scatter
  if self.dp_size != 1:
  # important: forward batch.gathered_buffer is used both after scatter and after gather.
@@ -1113,9 +1184,79 @@ class DeepseekV2DecoderLayer(nn.Module):

  return hidden_states, residual

+ def forward_deepep(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ forward_batch: ForwardBatch,
+ residual: Optional[torch.Tensor],
+ ) -> torch.Tensor:

- class DeepseekV2Model(nn.Module):
+ if hidden_states.shape[0] == 0:
+ residual = hidden_states
+ else:
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)

+ if self.attn_tp_size != 1 and self.input_is_scattered:
+ hidden_states, local_hidden_states = (
+ forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+ hidden_states,
+ )
+ tp_all_gather(
+ list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+ )
+
+ # Self Attention
+ hidden_states = self.self_attn(
+ positions=positions,
+ hidden_states=hidden_states,
+ forward_batch=forward_batch,
+ )
+
+ if self.attn_tp_size != 1:
+ if self.input_is_scattered:
+ tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+ hidden_states = tensor_list[self.attn_tp_rank]
+ tp_reduce_scatter(hidden_states, tensor_list)
+ if hidden_states.shape[0] != 0:
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual
+ )
+ else:
+ if self.attn_tp_rank == 0:
+ hidden_states += residual
+ tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+ hidden_states = tensor_list[self.attn_tp_rank]
+ tp_reduce_scatter(hidden_states, tensor_list)
+ residual = hidden_states
+ if hidden_states.shape[0] != 0:
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ else:
+ if hidden_states.shape[0] != 0:
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual
+ )
+ hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+
+ if self.is_last_layer and self.attn_tp_size != 1:
+ hidden_states += residual
+ residual = None
+ hidden_states, local_hidden_states = (
+ forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+ hidden_states,
+ )
+ tp_all_gather(
+ list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+ )
+
+ return hidden_states, residual
+
+
+ class DeepseekV2Model(nn.Module):
  fall_back_to_pt_during_load = False

  def __init__(
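
The new forward_deepep path shards hidden states across attention TP ranks with tensor_split, keeping one shard per rank on the reduce-scatter side and writing shards back into a contiguous buffer on the all-gather side. A single-process sketch of just the slicing, without the actual tp_reduce_scatter/tp_all_gather collectives:

    import torch

    attn_tp_size, attn_tp_rank = 2, 0
    hidden_states = torch.randn(6, 8)

    # "reduce-scatter" side: each rank keeps only its own shard.
    tensor_list = list(hidden_states.tensor_split(attn_tp_size))
    local_shard = tensor_list[attn_tp_rank]
    assert local_shard.shape[0] == hidden_states.shape[0] // attn_tp_size

    # "all-gather" side: every rank's shard is written back into one buffer.
    gathered = torch.empty_like(hidden_states)
    for rank, shard in enumerate(hidden_states.tensor_split(attn_tp_size)):
        list(gathered.tensor_split(attn_tp_size))[rank].copy_(shard)
    assert torch.equal(gathered, hidden_states)
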
@@ -1169,7 +1310,10 @@ class DeepseekV2Model(nn.Module):
  positions, hidden_states, forward_batch, residual
  )
  if not forward_batch.forward_mode.is_idle():
- hidden_states, _ = self.norm(hidden_states, residual)
+ if residual is None:
+ hidden_states = self.norm(hidden_states)
+ else:
+ hidden_states, _ = self.norm(hidden_states, residual)
  return hidden_states


@@ -1183,7 +1327,28 @@ class DeepseekV2ForCausalLM(nn.Module):
  ) -> None:
  super().__init__()
  self.config = config
+ self.tp_size = get_tensor_model_parallel_world_size()
  self.quant_config = quant_config
+ self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+ # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+ if (
+ global_server_args_dict.get("disable_shared_experts_fusion", False)
+ or self.config.architectures[0] != "DeepseekV3ForCausalLM"
+ or self.config.n_routed_experts != 256
+ or self.config.routed_scaling_factor != 2.5
+ ):
+ self.n_share_experts_fusion = None
+ global_server_args_dict["n_share_experts_fusion"] = None
+ logger.info(
+ "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+ )
+ elif self.n_share_experts_fusion is None:
+ global_server_args_dict["n_share_experts_fusion"] = self.tp_size
+ self.n_share_experts_fusion = self.tp_size
+ logger.info(
+ f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
+ )
+
  self.model = DeepseekV2Model(
  config, quant_config, prefix=add_prefix("model", prefix)
  )
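
The eligibility check above turns shared-experts fusion off for anything that is not a DeepSeek-V3/R1-shaped checkpoint, and otherwise defaults n_share_experts_fusion to the TP size. A compact sketch of that decision, using a plain dict as a stand-in for the HF config plus server args (the non-matching config values are illustrative):

    def resolve_n_share_experts_fusion(cfg: dict, requested, tp_size: int):
        eligible = (
            not cfg.get("disable_shared_experts_fusion", False)
            and cfg["architectures"][0] == "DeepseekV3ForCausalLM"
            and cfg["n_routed_experts"] == 256
            and cfg["routed_scaling_factor"] == 2.5
        )
        if not eligible:
            return None  # fusion disabled entirely
        return requested if requested is not None else tp_size  # default to tp_size

    v3_like = {"architectures": ["DeepseekV3ForCausalLM"],
               "n_routed_experts": 256, "routed_scaling_factor": 2.5}
    other = {"architectures": ["DeepseekV2ForCausalLM"],
             "n_routed_experts": 160, "routed_scaling_factor": 16.0}
    assert resolve_n_share_experts_fusion(v3_like, None, 8) == 8
    assert resolve_n_share_experts_fusion(other, 4, 8) is None
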
@@ -1196,6 +1361,9 @@ class DeepseekV2ForCausalLM(nn.Module):
  self.logits_processor = LogitsProcessor(config)
  self.dp_size = get_attention_dp_size()

+ def get_input_embeddings(self) -> nn.Embedding:
+ return self.model.embed_tokens
+
  @torch.no_grad()
  def forward(
  self,
@@ -1211,12 +1379,127 @@ class DeepseekV2ForCausalLM(nn.Module):
  input_ids, hidden_states, self.lm_head, forward_batch
  )

+ def post_load_weights(self):
+
+ # Perform post-processing after loading weights
+
+ if not global_server_args_dict["disable_mla"]:
+ for layer_id in range(self.config.num_hidden_layers):
+ self_attn = self.model.layers[layer_id].self_attn
+ if hasattr(self_attn.kv_b_proj, "qweight"):
+ # AWQ compatible
+ if _is_cuda:
+ w = awq_dequantize(
+ self_attn.kv_b_proj.qweight,
+ self_attn.kv_b_proj.scales,
+ self_attn.kv_b_proj.qzeros,
+ ).T
+ else:
+ w = ops.awq_dequantize(
+ self_attn.kv_b_proj.qweight,
+ self_attn.kv_b_proj.scales,
+ self_attn.kv_b_proj.qzeros,
+ 0,
+ 0,
+ 0,
+ ).T
+ else:
+ w = self_attn.kv_b_proj.weight
+ # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+ # This may affect the accuracy of fp8 model.
+ if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+ torch.float8_e4m3fn,
+ torch.float8_e4m3fnuz,
+ ):
+ weight_block_size = self.quant_config.weight_block_size
+ if weight_block_size is not None:
+ assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+ if _is_hip:
+ weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+ weight=w,
+ weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+ input_scale=None,
+ )
+ else:
+ weight = w
+ weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+ w, scale = block_quant_to_tensor_quant(
+ weight, weight_scale, weight_block_size
+ )
+ self_attn.w_scale = scale
+ if w.dtype == torch.int8:
+ if hasattr(self.quant_config, "weight_block_size"):
+ # block-wise int8 need it
+ weight_block_size = self.quant_config.weight_block_size
+ if weight_block_size is not None:
+ assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+ weight = w
+ weight_scale = self_attn.kv_b_proj.weight_scale_inv
+ w = int8_block_dequant(
+ weight, weight_scale, weight_block_size
+ ).to(torch.bfloat16)
+ else:
+ # channel-wise int8 need it
+ w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+ torch.bfloat16
+ )
+ w_kc, w_vc = w.unflatten(
+ 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+ ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+ self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+ self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+ if (
+ hasattr(self_attn.kv_b_proj, "weight_scale")
+ and self_attn.w_scale is None
+ ):
+ self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+ if _is_hip:
+ self_attn.w_scale *= 2.0
+
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
  # (param_name, shard_name, shard_id)
  ("gate_up_proj", "gate_proj", 0),
  ("gate_up_proj", "up_proj", 1),
  ]
+ if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
+ weights_list = list(weights)
+ weights_dict = dict(weights_list)
+ suffix_list = [
+ "down_proj.weight",
+ "down_proj.weight_scale_inv",
+ "gate_proj.weight",
+ "gate_proj.weight_scale_inv",
+ "up_proj.weight",
+ "up_proj.weight_scale_inv",
+ ]
+ names_to_remove = []
+ for moe_layer in tqdm(
+ range(
+ self.config.first_k_dense_replace,
+ self.config.num_hidden_layers,
+ self.config.moe_layer_freq,
+ ),
+ desc=f"Cloning {self.n_share_experts_fusion} "
+ "replicas of the shared expert into MoE",
+ ):
+ for num_repeat in range(self.n_share_experts_fusion):
+ for suffix in suffix_list:
+ shared_expert_weight_name = (
+ f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
+ )
+ weights_list.append(
+ (
+ f"model.layers.{moe_layer}."
+ f"mlp.experts."
+ f"{self.config.n_routed_experts + num_repeat}"
+ f".{suffix}",
+ weights_dict[shared_expert_weight_name].clone(),
+ )
+ )
+ names_to_remove += [shared_expert_weight_name]
+ weights = [w for w in weights_list if w[0] not in names_to_remove]

  # Params for weights, fp8 weight scales, fp8 activation scales
  # (param_name, weight_name, expert_id, shard_id)
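
When fusion is active, load_weights() rewrites shared-expert tensor names into the extra routed-expert slots before normal loading. A toy, name-only version of that rewrite (strings stand in for tensors; the real code clones the tensors and later drops the shared_experts entries):

    n_routed_experts, n_share_experts_fusion = 4, 2
    moe_layer, suffix = 3, "down_proj.weight"

    shared_name = f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
    weights = {shared_name: "W_shared"}

    for num_repeat in range(n_share_experts_fusion):
        fused_name = (
            f"model.layers.{moe_layer}.mlp.experts."
            f"{n_routed_experts + num_repeat}.{suffix}"
        )
        weights[fused_name] = weights[shared_name]  # clone() in the real code

    assert f"model.layers.3.mlp.experts.4.{suffix}" in weights
    assert f"model.layers.3.mlp.experts.5.{suffix}" in weights
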
@@ -1229,7 +1512,12 @@ class DeepseekV2ForCausalLM(nn.Module):
  ckpt_gate_proj_name="gate_proj",
  ckpt_down_proj_name="down_proj",
  ckpt_up_proj_name="up_proj",
- num_experts=self.config.n_routed_experts,
+ num_experts=self.config.n_routed_experts
+ + (
+ self.n_share_experts_fusion
+ if self.n_share_experts_fusion is not None
+ else 0
+ ),
  )

  params_dict = dict(self.named_parameters())
@@ -1293,79 +1581,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  )
  weight_loader(param, loaded_weight)

- if not global_server_args_dict["disable_mla"]:
- for layer_id in range(self.config.num_hidden_layers):
- self_attn = self.model.layers[layer_id].self_attn
- if hasattr(self_attn.kv_b_proj, "qweight"):
- # AWQ compatible
- if _is_cuda:
- w = awq_dequantize(
- self_attn.kv_b_proj.qweight,
- self_attn.kv_b_proj.scales,
- self_attn.kv_b_proj.qzeros,
- ).T
- else:
- w = ops.awq_dequantize(
- self_attn.kv_b_proj.qweight,
- self_attn.kv_b_proj.scales,
- self_attn.kv_b_proj.qzeros,
- 0,
- 0,
- 0,
- ).T
- else:
- w = self_attn.kv_b_proj.weight
- # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
- # This may affect the accuracy of fp8 model.
- if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
- torch.float8_e4m3fn,
- torch.float8_e4m3fnuz,
- ):
- weight_block_size = self.quant_config.weight_block_size
- if weight_block_size is not None:
- assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
- if _is_hip:
- weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
- weight=w,
- weight_scale=self_attn.kv_b_proj.weight_scale_inv,
- input_scale=None,
- )
- else:
- weight = w
- weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
- w, scale = block_quant_to_tensor_quant(
- weight, weight_scale, weight_block_size
- )
- self_attn.w_scale = scale
- if w.dtype == torch.int8:
- if hasattr(self.quant_config, "weight_block_size"):
- # block-wise int8 need it
- weight_block_size = self.quant_config.weight_block_size
- if weight_block_size is not None:
- assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
- weight = w
- weight_scale = self_attn.kv_b_proj.weight_scale_inv
- w = int8_block_dequant(
- weight, weight_scale, weight_block_size
- ).to(torch.bfloat16)
- else:
- # channel-wise int8 need it
- w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
- torch.bfloat16
- )
- w_kc, w_vc = w.unflatten(
- 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
- ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
- self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
- self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
- if (
- hasattr(self_attn.kv_b_proj, "weight_scale")
- and self_attn.w_scale is None
- ):
- self_attn.w_scale = self_attn.kv_b_proj.weight_scale
- if _is_hip:
- self_attn.w_scale *= 2.0
+ self.post_load_weights()

  def get_embed_and_head(self):
  return self.model.embed_tokens.weight, self.lm_head.weight
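
The block removed in the hunk above now lives in post_load_weights() (see the +1379 hunk earlier). Its core splits the dequantized kv_b_proj weight into the w_kc/w_vc blocks used for MLA weight absorption; a shape-only sketch with toy dimensions and random weights (the head dims are chosen for illustration, not read from a real checkpoint):

    import torch

    num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 128, 128, 512
    w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

    # Rows interleave [qk_nope | v] per head; unflatten + split separates them.
    w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
        [qk_nope_head_dim, v_head_dim], dim=1
    )
    assert w_kc.shape == (num_heads, qk_nope_head_dim, kv_lora_rank)
    assert w_vc.transpose(1, 2).shape == (num_heads, kv_lora_rank, v_head_dim)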