sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
@@ -18,8 +18,7 @@
 
 import logging
 import os
-from dataclasses import dataclass
-from enum import Enum, IntEnum, auto
+from enum import IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -29,17 +28,17 @@ from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.communicator import (
+    LayerCommunicator,
+    LayerScatterModes,
+    enable_moe_dense_fully_dp,
+)
 from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
@@ -52,9 +51,8 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
@@ -72,15 +70,21 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.operations import execute_operations
+from sglang.srt.operations_strategy import compute_layer_operations
 from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
@@ -109,8 +113,6 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )
 
-expert_distribution_recorder = ExpertDistributionRecorder()
-
 logger = logging.getLogger(__name__)
 
 
@@ -125,6 +127,9 @@ class AttnForwardMethod(IntEnum):
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
 
+    # Use MLA but with fused RoPE
+    MLA_FUSED_ROPE = auto()
+
 
 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -139,6 +144,8 @@ class DeepseekV2MLP(nn.Module):
         tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
+        self.tp_size = tp_size
+
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
@@ -165,7 +172,10 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x, forward_mode: Optional[ForwardMode] = None):
+    def forward(self, x, forward_batch=None):
+        if (self.tp_size == 1) and x.shape[0] == 0:
+            return x
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -194,11 +204,20 @@ class MoEGate(nn.Module):
         return logits
 
 
+def is_non_idle_and_non_empty(forward_mode, hidden_states):
+    return (
+        (forward_mode is not None)
+        and not forward_mode.is_idle()
+        and hidden_states.shape[0] > 0
+    )
+
+
 class DeepseekV2MoE(nn.Module):
 
     def __init__(
         self,
         config: PretrainedConfig,
+        layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -207,6 +226,7 @@ class DeepseekV2MoE(nn.Module):
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+        self.layer_id = layer_id
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -222,17 +242,14 @@ class DeepseekV2MoE(nn.Module):
 
         self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
 
-        MoEImpl = (
-            DeepEPMoE
-            if global_server_args_dict["enable_deepep_moe"]
-            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
-        )
-
-        self.experts = MoEImpl(
-            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
+        self.experts = get_moe_impl_class()(
+            num_experts=config.n_routed_experts
+            + self.n_share_experts_fusion
+            + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
+            layer_id=self.layer_id,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
@@ -251,32 +268,29 @@ class DeepseekV2MoE(nn.Module):
         if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
-            if not global_server_args_dict["enable_deepep_moe"]:
-                self.shared_experts = DeepseekV2MLP(
-                    hidden_size=config.hidden_size,
-                    intermediate_size=intermediate_size,
-                    hidden_act=config.hidden_act,
-                    quant_config=quant_config,
-                    reduce_results=False,
-                    prefix=add_prefix("shared_experts", prefix),
-                )
-            else:
-                self.shared_experts = DeepseekV2MLP(
-                    hidden_size=config.hidden_size,
-                    intermediate_size=intermediate_size,
-                    hidden_act=config.hidden_act,
-                    quant_config=quant_config,
-                    reduce_results=False,
-                    prefix=add_prefix("shared_experts", prefix),
-                    tp_rank=0,
-                    tp_size=1,
-                )
+            self.shared_experts = DeepseekV2MLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if global_server_args_dict["enable_deepep_moe"]
+                    else {}
+                ),
+            )
+
+        self.top_k = config.num_experts_per_tok
 
         if global_server_args_dict["enable_deepep_moe"]:
            # TODO: we will support tp < ep in the future
            self.ep_size = get_tensor_model_parallel_world_size()
-            self.num_experts = config.n_routed_experts
-            self.top_k = config.num_experts_per_tok
+            self.num_experts = (
+                config.n_routed_experts
+                + global_server_args_dict["ep_num_redundant_experts"]
+            )
             self.renormalize = config.norm_topk_prob
             self.topk_group = config.topk_group
             self.num_expert_group = config.n_group
@@ -290,7 +304,7 @@ class DeepseekV2MoE(nn.Module):
                 group=parallel_state.get_tp_group().device_group,
                 router_topk=self.top_k,
                 permute_fusion=True,
-                num_experts=config.n_routed_experts,
+                num_experts=self.num_experts,
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
@@ -299,105 +313,137 @@ class DeepseekV2MoE(nn.Module):
                 return_recv_hook=True,
             )
 
-    def forward(
-        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
-    ) -> torch.Tensor:
-        if not global_server_args_dict["enable_deepep_moe"]:
-            return self.forward_normal(hidden_states)
-        else:
-            return self.forward_deepep(hidden_states, forward_mode)
+    @property
+    def _enable_deepep_moe(self):
+        return global_server_args_dict["enable_deepep_moe"]
 
-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        shared_output = self._forward_shared_experts(hidden_states)
-        # router_logits: (num_tokens, n_experts)
-        router_logits = self.gate(hidden_states)
-        final_hidden_states = (
-            self.experts(hidden_states=hidden_states, router_logits=router_logits)
-            * self.routed_scaling_factor
-        )
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-        return final_hidden_states
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
 
-    def forward_deepep(
-        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
-    ) -> torch.Tensor:
-        shared_output = None
-        if (
-            forward_mode is not None
-            and not forward_mode.is_idle()
-            and hidden_states.shape[0] > 0
+    def op_gate(self, state):
+        if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, state.hidden_states_mlp_input
         ):
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
-            shared_output = self._forward_shared_experts(hidden_states)
-            topk_weights, topk_idx = select_experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=True,
-                renormalize=self.renormalize,
-                topk_group=self.topk_group,
-                num_expert_group=self.num_expert_group,
-                correction_bias=self.correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-            )
+            state.router_logits = self.gate(state.hidden_states_mlp_input)
         else:
-            topk_idx = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            topk_weights = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            state.router_logits = None
+
+    def op_shared_experts(self, state):
+        if (self.n_share_experts_fusion == 0) and (
+            (not self._enable_deepep_moe)
+            or is_non_idle_and_non_empty(
+                state.forward_batch.forward_mode, state.hidden_states_mlp_input
             )
-        if self.ep_size > 1:
+        ):
+            state.shared_output = self.shared_experts(state.hidden_states_mlp_input)
+        else:
+            state.shared_output = None
+
+    def op_select_experts(self, state):
+        router_logits = state.router_logits
+        hidden_states = state.hidden_states_mlp_input
+
+        if self._enable_deepep_moe:
+            if router_logits is not None:
+                state.topk_weights_local, state.topk_idx_local = select_experts(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    top_k=self.top_k,
+                    use_grouped_topk=True,
+                    renormalize=self.renormalize,
+                    topk_group=self.topk_group,
+                    num_expert_group=self.num_expert_group,
+                    correction_bias=self.correction_bias,
+                    routed_scaling_factor=self.routed_scaling_factor,
+                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                        layer_id=self.layer_id,
+                    ),
+                )
+            else:
+                state.topk_idx_local = torch.full(
+                    (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+                )
+                state.topk_weights_local = torch.empty(
+                    (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+                )
+
+    def op_dispatch_a(self, state):
+        if self._enable_deepep_moe and (self.ep_size > 1):
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            self.deepep_dispatcher.dispatch_a(
+                hidden_states=state.pop("hidden_states_mlp_input"),
+                topk_idx=state.pop("topk_idx_local"),
+                topk_weights=state.pop("topk_weights_local"),
+                forward_mode=state.forward_batch.forward_mode,
+            )
+
+    def op_dispatch_b(self, state):
+        if self._enable_deepep_moe and (self.ep_size > 1):
             (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                num_recv_tokens_per_expert,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            ) = self.deepep_dispatcher.dispatch(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_mode=forward_mode,
+                state.hidden_states_experts_input,
+                state.topk_idx_dispatched,
+                state.topk_weights_dispatched,
+                state.reorder_topk_ids,
+                state.num_recv_tokens_per_expert,
+                state.seg_indptr,
+                state.masked_m,
+                state.expected_m,
+            ) = self.deepep_dispatcher.dispatch_b()
+
+    def op_experts(self, state):
+        if self._enable_deepep_moe:
+            state.pop("router_logits")
+            state.hidden_states_experts_output = self.experts(
+                hidden_states=state.pop("hidden_states_experts_input"),
+                topk_idx=state.topk_idx_dispatched,
+                topk_weights=state.topk_weights_dispatched,
+                reorder_topk_ids=state.pop("reorder_topk_ids"),
+                seg_indptr=state.pop("seg_indptr"),
+                masked_m=state.pop("masked_m"),
+                expected_m=state.pop("expected_m"),
+                num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
+                forward_mode=state.forward_batch.forward_mode,
             )
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            reorder_topk_ids=reorder_topk_ids,
-            seg_indptr=seg_indptr,
-            masked_m=masked_m,
-            expected_m=expected_m,
-            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-            forward_mode=forward_mode,
-        )
-        if self.ep_size > 1:
-            final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_mode,
+        else:
+            state.hidden_states_experts_output = self.experts(
+                hidden_states=state.pop("hidden_states_mlp_input"),
+                router_logits=state.pop("router_logits"),
             )
+
+    def op_combine_a(self, state):
+        if self._enable_deepep_moe and (self.ep_size > 1):
+            self.deepep_dispatcher.combine_a(
+                state.pop("hidden_states_experts_output"),
+                topk_idx=state.pop("topk_idx_dispatched"),
+                topk_weights=state.pop("topk_weights_dispatched"),
+                forward_mode=state.forward_batch.forward_mode,
+            )
+
+    def op_combine_b(self, state):
+        if self._enable_deepep_moe and (self.ep_size > 1):
+            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b()
+
+    def op_output(self, state):
+        final_hidden_states = (
+            state.pop("hidden_states_after_combine")
+            if self._enable_deepep_moe
+            else state.pop("hidden_states_experts_output")
+        )
+
         final_hidden_states *= self.routed_scaling_factor
 
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+        if (s := state.pop("shared_output")) is not None:
+            final_hidden_states = final_hidden_states + s
 
-        return final_hidden_states
+        if (not self._enable_deepep_moe) and (self.tp_size > 1):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
-    def _forward_shared_experts(self, hidden_states):
-        if self.n_share_experts_fusion == 0:
-            return self.shared_experts(hidden_states)
-        else:
-            return None
+        state.hidden_states_mlp_output = final_hidden_states
 
 
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
@@ -578,6 +624,18 @@ class DeepseekV2AttentionMLA(nn.Module):
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
+        def _dispatch_mla_subtype():
+            if _is_hip:
+                if (
+                    self.rocm_fused_decode_mla
+                    and forward_batch.forward_mode.is_decode()
+                ):
+                    return AttnForwardMethod.MLA_FUSED_ROPE
+                else:
+                    return AttnForwardMethod.MLA
+            else:
+                return AttnForwardMethod.MLA
+
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
@@ -589,7 +647,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
         elif self.attention_backend == "fa3":
             # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
             if forward_batch.extend_prefix_lens_cpu is not None:
@@ -606,7 +664,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             if (
@@ -617,7 +675,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
 
     def forward(
         self,
@@ -640,23 +698,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_chunked_kv(
                 positions, hidden_states, forward_batch
             )
+        elif attn_forward_method == AttnForwardMethod.MLA:
+            return self.forward_absorb(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
+            return self.forward_absorb_fused_mla_rope(
+                positions, hidden_states, forward_batch
+            )
         else:
-            if _is_hip:
-                if (
-                    self.rocm_fused_decode_mla
-                    and forward_batch.forward_mode.is_decode()
-                ):
-                    return self.forward_absorb_fused_mla_rope(
-                        positions, hidden_states, forward_batch
-                    )
-                else:
-                    return self.forward_absorb(
-                        positions, hidden_states, forward_batch, zero_allocator
-                    )
-            else:
-                return self.forward_absorb(
-                    positions, hidden_states, forward_batch, zero_allocator
-                )
+            raise NotImplementedError
 
     def forward_normal(
         self,
@@ -710,6 +761,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         if self.q_lora_rank is not None:
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
@@ -717,7 +770,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
             k_nope = latent_cache[..., : self.kv_lora_rank]
             # overlap qk norm
-            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+            if self.alt_stream is not None and get_is_capture_mode():
                 current_stream = torch.cuda.current_stream()
                 self.alt_stream.wait_stream(current_stream)
                 q = self.q_a_layernorm(q)
@@ -1101,19 +1154,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
 
-class _FFNInputMode(Enum):
-    # The MLP sublayer requires 1/tp_size tokens as input
-    SCATTERED = auto()
-    # The MLP sublayer requires all tokens as input
-    FULL = auto()
-
-
-@dataclass
-class _DecoderLayerInfo:
-    is_sparse: bool
-    ffn_input_mode: _FFNInputMode
-
-
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
@@ -1127,14 +1167,12 @@ class DeepseekV2DecoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.config = config
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
-        self.local_dp_size = get_local_attention_dp_size()
-        self.attn_tp_size = get_attention_tp_size()
-        self.attn_tp_rank = get_attention_tp_rank()
         self.self_attn = DeepseekV2AttentionMLA(
             config=config,
             hidden_size=self.hidden_size,
@@ -1156,19 +1194,25 @@ class DeepseekV2DecoderLayer(nn.Module):
             alt_stream=alt_stream,
         )
 
-        self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
-        previous_layer_info = self._compute_info(
-            config, layer_id=layer_id - 1, is_nextn=False
+        self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
+        is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
         )
 
-        if self.info.is_sparse:
+        if self.is_layer_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
+                layer_id=self.layer_id,
             )
         else:
-            if self._enable_moe_dense_fully_dp():
+            if enable_moe_dense_fully_dp():
                 mlp_tp_rank, mlp_tp_size = 0, 1
             else:
                 mlp_tp_rank, mlp_tp_size = None, None
@@ -1182,35 +1226,23 @@ class DeepseekV2DecoderLayer(nn.Module):
                 tp_size=mlp_tp_size,
             )
 
-        self.input_is_scattered = (
-            layer_id > 0
-            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
-        )
-        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
-
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )
 
-    @staticmethod
-    def _enable_moe_dense_fully_dp():
-        return global_server_args_dict["moe_dense_tp_size"] == 1
-
-    @staticmethod
-    def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
-        is_sparse = is_nextn or (
-            config.n_routed_experts is not None
-            and layer_id >= config.first_k_dense_replace
-            and layer_id % config.moe_layer_freq == 0
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
         )
-        ffn_input_mode = (
-            _FFNInputMode.SCATTERED
-            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
-            or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
-            else _FFNInputMode.FULL
+
+    def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
+        return is_nextn or (
+            self.config.n_routed_experts is not None
+            and layer_id >= self.config.first_k_dense_replace
+            and layer_id % self.config.moe_layer_freq == 0
        )
-        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
 
     def forward(
         self,
@@ -1220,163 +1252,75 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
-        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
-            return self.forward_ffn_with_scattered_input(
-                positions, hidden_states, forward_batch, residual, zero_allocator
-            )
-        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
-            return self.forward_ffn_with_full_input(
-                positions, hidden_states, forward_batch, residual, zero_allocator
-            )
-        else:
-            raise NotImplementedError
-
-    def forward_ffn_with_full_input(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-        residual: Optional[torch.Tensor],
-        zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
-
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        assert not (
-            self.attn_tp_size != 1 and self.input_is_scattered
-        ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
-
-        # Self Attention
-        hidden_states = self.self_attn(
+        return execute_operations(
+            inputs=dict(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
+                residual=residual,
                zero_allocator=zero_allocator,
-        )
-
-        # Gather
-        if get_tensor_model_parallel_world_size() > 1:
-            # all gather and all reduce
-            if self.local_dp_size != 1:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                dp_scatter(residual, hidden_states, forward_batch)
-                hidden_states = self.post_attention_layernorm(hidden_states)
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
-        else:
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
-
-        # Fully Connected
-        hidden_states = self.mlp(hidden_states)
-
-        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
-        # Scatter
-        if self.local_dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
-        return hidden_states, residual
+            ),
+            operations=compute_layer_operations(self),
+        )
 
-    def forward_ffn_with_scattered_input(
+    def op_comm_prepare_attn(
         self,
+        state,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
-
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        if self.attn_tp_size != 1 and self.input_is_scattered:
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            attn_tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+    ):
+        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
+            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
+        )
+        state.update(
+            dict(
+                forward_batch=forward_batch,
+                positions=positions,
+                zero_allocator=zero_allocator,
            )
+        )
 
-        # Self Attention
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-            zero_allocator=zero_allocator,
+    def op_attn(self, state):
+        state.hidden_states_after_attn = self.self_attn(
+            positions=state.positions,
+            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+            forward_batch=state.forward_batch,
+            zero_allocator=state.zero_allocator,
        )
 
-        if self.attn_tp_size != 1:
-            if self.input_is_scattered:
-                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-                hidden_states = tensor_list[self.attn_tp_rank]
-                attn_tp_reduce_scatter(hidden_states, tensor_list)
-                if hidden_states.shape[0] != 0:
-                    hidden_states, residual = self.post_attention_layernorm(
-                        hidden_states, residual
-                    )
-            else:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-                hidden_states = tensor_list[self.attn_tp_rank]
-                attn_tp_reduce_scatter(hidden_states, tensor_list)
-                residual = hidden_states
-                if hidden_states.shape[0] != 0:
-                    hidden_states = self.post_attention_layernorm(hidden_states)
-        else:
-            if hidden_states.shape[0] != 0:
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
+    def op_comm_prepare_mlp(self, state):
+        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
+            self.layer_communicator.prepare_mlp(
+                state.pop("hidden_states_after_attn"),
+                state.pop("residual_after_input_ln"),
+                state.forward_batch,
+            )
+        )
 
+    def op_mlp(self, state):
+        hidden_states = state.pop("hidden_states_mlp_input")
        if not (
-            self._enable_moe_dense_fully_dp()
-            and (not self.info.is_sparse)
+            enable_moe_dense_fully_dp()
+            and (not self.is_layer_sparse)
            and hidden_states.shape[0] == 0
        ):
-            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
-
-        if self.is_last_layer and self.attn_tp_size != 1:
-            hidden_states += residual
-            residual = None
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            attn_tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            state.hidden_states_mlp_output = self.mlp(
+                hidden_states, state.forward_batch.forward_mode
            )
+        else:
+            state.hidden_states_mlp_output = hidden_states
+
+    def op_comm_postprocess_layer(self, state):
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            state.pop("hidden_states_mlp_output"),
+            state.pop("residual_after_comm_pre_mlp"),
+            state.forward_batch,
+        )
 
+        state.clear(expect_keys={"positions", "forward_batch", "zero_allocator"})
        return hidden_states, residual
 
 
@@ -1398,7 +1342,7 @@ class DeepseekV2Model(nn.Module):
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
-        self.alt_stream = torch.cuda.Stream()
+        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
             [
                 DeepseekV2DecoderLayer(
@@ -1441,11 +1385,11 @@ class DeepseekV2Model(nn.Module):
 
         residual = None
         for i in range(len(self.layers)):
-            expert_distribution_recorder.set_current_layer(i)
-            layer = self.layers[i]
-            hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual, zero_allocator
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual, zero_allocator
+                )
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
                 hidden_states = self.norm(hidden_states)
@@ -1662,6 +1606,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                 self_attn.w_vc = w_vc.contiguous()
                 self_attn.use_deep_gemm_bmm = True
 
+        # TODO support nextn later
+        if not is_nextn:
+            self.routed_experts_weights_of_layer = {
+                layer_id: layer.mlp.get_moe_weights()
+                for layer_id, layer in enumerate(self.model.layers)
+                if isinstance(layer.mlp, DeepseekV2MoE)
+            }
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
@@ -1738,12 +1690,7 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = (
-            DeepEPMoE
-            if global_server_args_dict["enable_deepep_moe"]
-            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
-        )
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
@@ -1859,7 +1806,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                             q_a_proj_name in cached_a_proj
                             and kv_a_proj_name in cached_a_proj
                         ):
-
                             q_a_proj_weight = cached_a_proj[q_a_proj_name]
                             kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                             fused_weight = torch.cat(
@@ -1897,6 +1843,14 @@ class DeepseekV2ForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
 
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.n_routed_experts,
+            num_groups=config.n_group,
+        )
+
 
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass