sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_moe.py

@@ -17,21 +17,36 @@
 
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
 
+import logging
+from dataclasses import dataclass
+from enum import Enum, auto
 from functools import partial
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
 from torch import nn
+from transformers.configuration_utils import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
+    attn_tp_reduce_scatter,
+    dp_gather_partial,
+    dp_scatter,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    get_local_attention_dp_size,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -39,52 +54,69 @@ from sglang.srt.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import DeepEPMode, add_prefix
 
 Qwen3MoeConfig = None
 
+logger = logging.getLogger(__name__)
+
 
 class Qwen3MoeSparseMoeBlock(nn.Module):
     def __init__(
         self,
+        layer_id: int,
         config: Qwen3MoeConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
-
+        self.layer_id = layer_id
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {config.num_experts}."
             )
 
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
-
-        self.experts = MoEImpl(
-            num_experts=config.num_experts,
+        self.experts = get_moe_impl_class()(
+            num_experts=config.num_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
+            layer_id=layer_id,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
+            **(
+                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                if global_server_args_dict["enable_deepep_moe"]
+                else {}
+            ),
         )
 
         self.gate = ReplicatedLinear(
@@ -95,7 +127,45 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             prefix=add_prefix("gate", prefix),
         )
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if global_server_args_dict["enable_deepep_moe"]:
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_tensor_model_parallel_world_size()
+            self.num_experts = (
+                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            )
+            self.top_k = config.num_experts_per_tok
+            self.renormalize = config.norm_topk_prob
+
+            self.deepep_dispatcher = DeepEPDispatcher(
+                group=parallel_state.get_tp_group().device_group,
+                router_topk=self.top_k,
+                permute_fusion=True,
+                num_experts=self.num_experts,
+                num_local_experts=config.num_experts // self.tp_size,
+                hidden_size=config.hidden_size,
+                params_dtype=config.torch_dtype,
+                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+                async_finish=True,  # TODO
+                return_recv_hook=True,
+            )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+    ) -> torch.Tensor:
+
+        if not global_server_args_dict["enable_deepep_moe"]:
+            return self.forward_normal(hidden_states)
+        else:
+            return self.forward_deepep(hidden_states, forward_mode)
+
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
@@ -109,6 +179,71 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
 
         return final_hidden_states.view(num_tokens, hidden_dim)
 
+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+    ) -> torch.Tensor:
+        if (
+            forward_mode is not None
+            and not forward_mode.is_idle()
+            and hidden_states.shape[0] > 0
+        ):
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+
+            topk_weights, topk_idx = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=False,
+                renormalize=self.renormalize,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            topk_idx = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            topk_weights = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+        if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                num_recv_tokens_per_expert,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            ) = self.deepep_dispatcher.dispatch(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode=forward_mode,
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            reorder_topk_ids=reorder_topk_ids,
+            seg_indptr=seg_indptr,
+            masked_m=masked_m,
+            expected_m=expected_m,
+            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
+            forward_mode=forward_mode,
+        )
+        if self.ep_size > 1:
+            final_hidden_states = self.deepep_dispatcher.combine(
+                final_hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode,
+            )
+        return final_hidden_states
+
 
 class Qwen3MoeAttention(nn.Module):
     def __init__(
@@ -128,20 +263,23 @@ class Qwen3MoeAttention(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        self.tp_size = get_tensor_model_parallel_world_size()
+
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.total_num_heads = num_heads
-        assert self.total_num_heads % self.tp_size == 0
-        self.num_heads = self.total_num_heads // self.tp_size
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= self.tp_size:
+        if self.total_num_kv_heads >= attn_tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % self.tp_size == 0
+            assert self.total_num_kv_heads % attn_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
         self.head_dim = head_dim or hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -157,6 +295,8 @@
             self.total_num_kv_heads,
             bias=attention_bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("qkv_proj", prefix),
         )
 
@@ -165,6 +305,9 @@
             hidden_size,
             bias=attention_bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            reduce_results=False,
             prefix=add_prefix("o_proj", prefix),
         )
 
@@ -213,6 +356,19 @@
         return output
 
 
+class _FFNInputMode(Enum):
+    # The MLP sublayer requires 1/tp_size tokens as input
+    SCATTERED = auto()
+    # The MLP sublayer requires all tokens as input
+    FULL = auto()
+
+
+@dataclass
+class _DecoderLayerInfo:
+    is_sparse: bool
+    ffn_input_mode: _FFNInputMode
+
+
 class Qwen3MoeDecoderLayer(nn.Module):
     def __init__(
         self,
@@ -246,15 +402,23 @@ class Qwen3MoeDecoderLayer(nn.Module):
             prefix=add_prefix("self_attn", prefix),
         )
 
-        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
-        # `mlp_only_layers` in the config.
-        mlp_only_layers = (
-            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
+        self.layer_id = layer_id
+
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.local_dp_size = get_local_attention_dp_size()
+
+        self.info = self._compute_info(config, layer_id=layer_id)
+        previous_layer_info = self._compute_info(config, layer_id=layer_id - 1)
+        self.input_is_scattered = (
+            layer_id > 0
+            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
-        if (layer_id not in mlp_only_layers) and (
-            config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
-        ):
+        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
+
+        if self.info.is_sparse:
             self.mlp = Qwen3MoeSparseMoeBlock(
+                layer_id=self.layer_id,
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
@@ -272,28 +436,182 @@ class Qwen3MoeDecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+    @staticmethod
+    def _enable_moe_dense_fully_dp():
+        return global_server_args_dict["moe_dense_tp_size"] == 1
+
+    @staticmethod
+    def _compute_info(config: PretrainedConfig, layer_id: int):
+        # WARN: Qwen3MOE has no dense_layer, it is only for compatibility.
+        mlp_only_layers = (
+            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
+        )
+        is_sparse = (layer_id not in mlp_only_layers) and (
+            config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
+        )
+        ffn_input_mode = (
+            _FFNInputMode.SCATTERED
+            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+            or (Qwen3MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
+            else _FFNInputMode.FULL
+        )
+        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        # Self Attention
-        if residual is None:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
+            return self.forward_ffn_with_scattered_input(
+                positions, hidden_states, forward_batch, residual
+            )
+        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
+            return self.forward_ffn_with_full_input(
+                positions, hidden_states, forward_batch, residual
+            )
+        else:
+            raise NotImplementedError
+
+    def forward_ffn_with_full_input(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if hidden_states.shape[0] == 0:
             residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        # Self Attention
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        # Gather
+        if get_tensor_model_parallel_world_size() > 1:
+            if self.local_dp_size != 1:
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                hidden_states = self.post_attention_layernorm(hidden_states)
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                # TODO extract this bugfix
+                if hidden_states.shape[0] != 0:
+                    hidden_states, residual = self.post_attention_layernorm(
+                        hidden_states, residual
+                    )
+        elif hidden_states.shape[0] != 0:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+
+        # TODO: use reduce-scatter in MLP to avoid this scatter
+        # Scatter
+        if self.local_dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
+        return hidden_states, residual
+
+    def forward_ffn_with_scattered_input(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if hidden_states.shape[0] == 0:
+            residual = hidden_states
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        if self.attn_tp_size != 1 and self.input_is_scattered:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            attn_tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+
+        # Self Attention
+        if hidden_states.shape[0] != 0:
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+        if self.attn_tp_size != 1:
+            if self.input_is_scattered:
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                attn_tp_reduce_scatter(hidden_states, tensor_list)
+                if hidden_states.shape[0] != 0:
+                    hidden_states, residual = self.post_attention_layernorm(
+                        hidden_states, residual
+                    )
+            else:
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                attn_tp_reduce_scatter(hidden_states, tensor_list)
+                residual = hidden_states
+                if hidden_states.shape[0] != 0:
+                    hidden_states = self.post_attention_layernorm(hidden_states)
+        else:
+            if hidden_states.shape[0] != 0:
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+
+        if not (
+            self._enable_moe_dense_fully_dp()
+            and (not self.info.is_sparse)
+            and hidden_states.shape[0] == 0
+        ):
+            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+
+        if self.is_last_layer and self.attn_tp_size != 1:
+            hidden_states += residual
+            residual = None
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            attn_tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+
         return hidden_states, residual
 
 
@@ -313,7 +631,6 @@ class Qwen3MoeModel(Qwen2MoeModel):
 
 
 class Qwen3MoeForCausalLM(nn.Module):
-
     fall_back_to_pt_during_load = False
 
     def __init__(
@@ -323,6 +640,7 @@ class Qwen3MoeForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen3MoeModel(
@@ -343,12 +661,31 @@ class Qwen3MoeForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )
 
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -359,9 +696,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
-
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
@@ -370,6 +705,17 @@
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -418,11 +764,28 @@ class Qwen3MoeForCausalLM(nn.Module):
                     if name not in params_dict:
                         continue
 
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
-                    )
-                    weight_loader(param, loaded_weight)
+                    if name in params_dict.keys():
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+                    else:
+                        logger.warning(f"Parameter {name} not found in params_dict")
+
+        self.routed_experts_weights_of_layer = {
+            layer_id: layer.mlp.get_moe_weights()
+            for layer_id, layer in enumerate(self.model.layers)
+            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock)
+        }
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_experts,
+            num_groups=None,
+        )
 
 
 EntryClass = Qwen3MoeForCausalLM
sglang/srt/models/roberta.py

@@ -57,7 +57,7 @@ class RobertaEmbedding(nn.Module):
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)
 
-        # adpated from vllm: https://github.com/vllm-project/vllm/commit/4a18fd14ba4a349291c798a16bf62fa8a9af0b6b/vllm/model_executor/models/roberta.py
+        # Adapted from vllm: https://github.com/vllm-project/vllm/commit/4a18fd14ba4a349291c798a16bf62fa8a9af0b6b/vllm/model_executor/models/roberta.py
 
         pos_list = []
         token_list = []