sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
@@ -23,10 +23,10 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm import _custom_ops as ops
 
 from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
+    parallel_state,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
@@ -34,11 +34,13 @@ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
     decode_attention_fwd_grouped_rope,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    tp_all_gather,
+    tp_reduce_scatter,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -48,8 +50,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
@@ -65,15 +69,21 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
+from sglang.srt.utils import add_prefix, is_cuda, is_hip
 
 _is_hip = is_hip()
+_is_cuda = is_cuda()
 
-if is_cuda_available():
-    from sgl_kernel import bmm_fp8
+if _is_cuda:
+    from sgl_kernel import awq_dequantize, bmm_fp8
+else:
+    from vllm import _custom_ops as ops
+
+expert_distribution_recorder = ExpertDistributionRecorder()
 
 
 class DeepseekV2MLP(nn.Module):
@@ -85,6 +95,8 @@ class DeepseekV2MLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -93,6 +105,8 @@ class DeepseekV2MLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -101,6 +115,8 @@ class DeepseekV2MLP(nn.Module):
             quant_config=quant_config,
             reduce_results=reduce_results,
             prefix=add_prefix("down_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -165,7 +181,11 @@ class DeepseekV2MoE(nn.Module):
 
         self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
 
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )
         self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
@@ -182,18 +202,60 @@ class DeepseekV2MoE(nn.Module):
 
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            self.shared_experts = DeepseekV2MLP(
+            # disable tp for shared experts when enable deepep moe
+            if not global_server_args_dict["enable_deepep_moe"]:
+                self.shared_experts = DeepseekV2MLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=intermediate_size,
+                    hidden_act=config.hidden_act,
+                    quant_config=quant_config,
+                    reduce_results=False,
+                    prefix=add_prefix("shared_experts", prefix),
+                )
+            else:
+                self.shared_experts = DeepseekV2MLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=intermediate_size,
+                    hidden_act=config.hidden_act,
+                    quant_config=quant_config,
+                    reduce_results=False,
+                    prefix=add_prefix("shared_experts", prefix),
+                    tp_rank=0,
+                    tp_size=1,
+                )
+
+        if global_server_args_dict["enable_deepep_moe"]:
+            self.num_experts = config.n_routed_experts
+            self.top_k = config.num_experts_per_tok
+            self.renormalize = config.norm_topk_prob
+            self.topk_group = config.topk_group
+            self.num_expert_group = config.n_group
+            self.correction_bias = (
+                self.gate.e_score_correction_bias.data
+                if self.gate.e_score_correction_bias is not None
+                else None
+            )
+
+            self.deepep_dispatcher = DeepEPDispatcher(
+                group=parallel_state.get_tp_group().device_group,
+                router_topk=self.top_k,
+                permute_fusion=True,
+                num_experts=config.n_routed_experts,
+                num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
-                intermediate_size=intermediate_size,
-                hidden_act=config.hidden_act,
-                quant_config=quant_config,
-                reduce_results=False,
-                prefix=add_prefix("shared_experts", prefix),
+                params_dtype=config.torch_dtype,
+                async_finish=True,  # TODO
             )
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+    def forward(
+        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+    ) -> torch.Tensor:
+        if not global_server_args_dict["enable_deepep_moe"]:
+            return self.forward_normal(hidden_states)
+        else:
+            return self.forward_deepep(hidden_states, forward_mode)
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
@@ -206,8 +268,64 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+    ) -> torch.Tensor:
+        shared_output = None
+        topk_idx = torch.full(
+            (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+        )
+        topk_weights = torch.empty(
+            (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+        )
+        if (
+            forward_mode is not None
+            and not forward_mode.is_idle()
+            and hidden_states.shape[0] > 0
+        ):
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            if self.n_shared_experts is not None:
+                shared_output = self.shared_experts(hidden_states)
+            topk_weights, topk_idx = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=True,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                correction_bias=self.correction_bias,
+            )
+        if self.tp_size > 1:
+            recv_hidden_states, reorder_topk_ids, seg_indptr = (
+                self.deepep_dispatcher.dispatch(
+                    hidden_states,
+                    topk_idx,
+                    topk_weights,
+                    self.num_experts,
+                    forward_mode,
+                )
+            )
+        final_hidden_states = (
+            self.experts(
+                hidden_states=recv_hidden_states,
+                reorder_topk_ids=reorder_topk_ids,
+                seg_indptr=seg_indptr,
+                forward_mode=forward_mode,
+            )
+            * self.routed_scaling_factor
+        )
+        if self.tp_size > 1:
+            final_hidden_states = self.deepep_dispatcher.combine(
+                final_hidden_states, forward_mode
+            )
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
 
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return final_hidden_states
 
 
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
@@ -537,6 +655,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
+        self.attention_backend = global_server_args_dict["attention_backend"]
         self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
 
     def no_absorb(self, forward_batch: ForwardBatch) -> bool:
@@ -547,15 +666,18 @@ class DeepseekV2AttentionMLA(nn.Module):
                 and forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and forward_batch.extend_prefix_lens.sum() == 0
+                and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )
+        elif self.attention_backend == "fa3":
+            # Flash Attention: Keep absorbing for all extend/decode
+            return False
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             return (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and forward_batch.extend_prefix_lens.sum() == 0
+                and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )
 
     def forward(
@@ -857,6 +979,14 @@ class DeepseekV2DecoderLayer(nn.Module):
         is_nextn: bool = False,
         prefix: str = "",
     ) -> None:
+
+        def is_sparse_layer(l: int):
+            return (
+                config.n_routed_experts is not None
+                and l >= config.first_k_dense_replace
+                and l % config.moe_layer_freq == 0
+            )
+
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -865,6 +995,8 @@ class DeepseekV2DecoderLayer(nn.Module):
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
         self.dp_size = get_attention_dp_size()
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
 
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
@@ -907,16 +1039,13 @@ class DeepseekV2DecoderLayer(nn.Module):
                 prefix=add_prefix("self_attn", prefix),
             )
 
-        if is_nextn or (
-            config.n_routed_experts is not None
-            and layer_id >= config.first_k_dense_replace
-            and layer_id % config.moe_layer_freq == 0
-        ):
+        if is_nextn or is_sparse_layer(layer_id):
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
+            self.is_sparse = True
         else:
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
@@ -925,6 +1054,14 @@ class DeepseekV2DecoderLayer(nn.Module):
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
+            self.is_sparse = False
+
+        self.input_is_scattered = (
+            is_sparse_layer(layer_id - 1)
+            and global_server_args_dict["enable_deepep_moe"]
+        )
+        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
+
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -937,12 +1074,82 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        if residual is None:
+        if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
+            return self.forward_deepep(
+                positions, hidden_states, forward_batch, residual
+            )
+        else:
+            return self.forward_normal(
+                positions, hidden_states, forward_batch, residual
+            )
+
+    def forward_normal(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+
+        if hidden_states.shape[0] == 0:
             residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        # Self Attention
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        if self.attn_tp_size != 1 and self.input_is_scattered:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+            residual, local_residual = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                residual,
+            )
+            tp_all_gather(
+                list(residual.tensor_split(self.attn_tp_size)), local_residual
+            )
+
+        # Gather
+        if get_tensor_model_parallel_world_size() > 1:
+            # all gather and all reduce
+            if self.dp_size != 1:
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                hidden_states = self.post_attention_layernorm(hidden_states)
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
 
+        # Fully Connected
+        hidden_states = self.mlp(hidden_states)
+
+        # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
         # Scatter
         if self.dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
@@ -953,6 +1160,34 @@ class DeepseekV2DecoderLayer(nn.Module):
             )
             dp_scatter(hidden_states, global_hidden_states, forward_batch)
 
+        return hidden_states, residual
+
+    def forward_deepep(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+
+        if hidden_states.shape[0] == 0:
+            residual = hidden_states
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        if self.attn_tp_size != 1 and self.input_is_scattered:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
@@ -960,24 +1195,47 @@ class DeepseekV2DecoderLayer(nn.Module):
             forward_batch=forward_batch,
         )
 
-        # Gather
-        if get_tensor_model_parallel_world_size() > 1:
-            # all gather and all reduce
-            if self.dp_size != 1:
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather(
-                    hidden_states, local_hidden_states, forward_batch, self.layer_id
-                )
+        if self.attn_tp_size != 1:
+            if self.input_is_scattered:
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                tp_reduce_scatter(hidden_states, tensor_list)
+                if hidden_states.shape[0] != 0:
+                    hidden_states, residual = self.post_attention_layernorm(
+                        hidden_states, residual
+                    )
             else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                tp_reduce_scatter(hidden_states, tensor_list)
+                residual = hidden_states
+                if hidden_states.shape[0] != 0:
+                    hidden_states = self.post_attention_layernorm(hidden_states)
+        else:
+            if hidden_states.shape[0] != 0:
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
 
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        if self.is_last_layer and self.attn_tp_size != 1:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+            residual, local_residual = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                residual,
+            )
+            tp_all_gather(
+                list(residual.tensor_split(self.attn_tp_size)), local_residual
+            )
 
-        # Fully Connected
-        hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
 
 
@@ -1020,23 +1278,17 @@ class DeepseekV2Model(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
 
-        # Gather
-        if self.dp_size != 1:
-            input_ids, local_input_ids = (
-                torch.empty(
-                    (forward_batch.gathered_buffer.shape[0],),
-                    dtype=input_ids.dtype,
-                    device=input_ids.device,
-                ),
-                input_ids,
-            )
-            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
 
-        hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
+            expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
@@ -1075,17 +1327,10 @@ class DeepseekV2ForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch)
 
-        if self.dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
 
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
@@ -1100,7 +1345,11 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )
         expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
@@ -1174,14 +1423,21 @@ class DeepseekV2ForCausalLM(nn.Module):
             self_attn = self.model.layers[layer_id].self_attn
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
-                w = ops.awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                    0,
-                    0,
-                    0,
-                ).T
+                if _is_cuda:
+                    w = awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                    ).T
+                else:
+                    w = ops.awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
             else:
                 w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.