sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
@@ -23,18 +23,23 @@ import torch
  import torch.nn.functional as F
  from torch import nn
  from transformers import PretrainedConfig
- from vllm import _custom_ops as ops

  from sglang.srt.distributed import (
- get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
- get_tp_group,
+ parallel_state,
  tensor_model_parallel_all_reduce,
  )
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
  decode_attention_fwd_grouped_rope,
  )
+ from sglang.srt.layers.dp_attention import (
+ dp_gather_partial,
+ dp_scatter,
+ get_attention_dp_size,
+ get_attention_tp_rank,
+ get_attention_tp_size,
+ )
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
  ColumnParallelLinear,
@@ -43,8 +48,10 @@ from sglang.srt.layers.linear import (
  RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
- from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+ from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+ from sglang.srt.layers.moe.topk import select_experts
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.quantization.fp8_utils import (
  block_quant_to_tensor_quant,
@@ -60,15 +67,21 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
  )
+ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
  from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
+ from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip

  _is_hip = is_hip()
+ _is_cuda = is_cuda()
+
+ if _is_cuda:
+ from sgl_kernel import awq_dequantize, bmm_fp8
+ else:
+ from vllm import _custom_ops as ops

- if is_cuda_available():
- from sgl_kernel import bmm_fp8
+ expert_distribution_recorder = ExpertDistributionRecorder()


  class DeepseekV2MLP(nn.Module):
@@ -80,6 +93,8 @@ class DeepseekV2MLP(nn.Module):
  quant_config: Optional[QuantizationConfig] = None,
  reduce_results: bool = True,
  prefix: str = "",
+ tp_rank: Optional[int] = None,
+ tp_size: Optional[int] = None,
  ) -> None:
  super().__init__()
  self.gate_up_proj = MergedColumnParallelLinear(
@@ -88,6 +103,8 @@ class DeepseekV2MLP(nn.Module):
  bias=False,
  quant_config=quant_config,
  prefix=add_prefix("gate_up_proj", prefix),
+ tp_rank=tp_rank,
+ tp_size=tp_size,
  )
  self.down_proj = RowParallelLinear(
  intermediate_size,
@@ -96,6 +113,8 @@ class DeepseekV2MLP(nn.Module):
  quant_config=quant_config,
  reduce_results=reduce_results,
  prefix=add_prefix("down_proj", prefix),
+ tp_rank=tp_rank,
+ tp_size=tp_size,
  )
  if hidden_act != "silu":
  raise ValueError(
@@ -160,7 +179,11 @@ class DeepseekV2MoE(nn.Module):

  self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))

- MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+ MoEImpl = (
+ DeepEPMoE
+ if global_server_args_dict["enable_deepep_moe"]
+ else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+ )
  self.experts = MoEImpl(
  num_experts=config.n_routed_experts,
  top_k=config.num_experts_per_tok,
@@ -177,18 +200,60 @@ class DeepseekV2MoE(nn.Module):

  if config.n_shared_experts is not None:
  intermediate_size = config.moe_intermediate_size * config.n_shared_experts
- self.shared_experts = DeepseekV2MLP(
+ # disable tp for shared experts when enable deepep moe
+ if not global_server_args_dict["enable_deepep_moe"]:
+ self.shared_experts = DeepseekV2MLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ reduce_results=False,
+ prefix=add_prefix("shared_experts", prefix),
+ )
+ else:
+ self.shared_experts = DeepseekV2MLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ reduce_results=False,
+ prefix=add_prefix("shared_experts", prefix),
+ tp_rank=0,
+ tp_size=1,
+ )
+
+ if global_server_args_dict["enable_deepep_moe"]:
+ self.num_experts = config.n_routed_experts
+ self.top_k = config.num_experts_per_tok
+ self.renormalize = config.norm_topk_prob
+ self.topk_group = config.topk_group
+ self.num_expert_group = config.n_group
+ self.correction_bias = (
+ self.gate.e_score_correction_bias.data
+ if self.gate.e_score_correction_bias is not None
+ else None
+ )
+
+ self.deepep_dispatcher = DeepEPDispatcher(
+ group=parallel_state.get_tp_group().device_group,
+ router_topk=self.top_k,
+ permute_fusion=True,
+ num_experts=config.n_routed_experts,
+ num_local_experts=config.n_routed_experts // self.tp_size,
  hidden_size=config.hidden_size,
- intermediate_size=intermediate_size,
- hidden_act=config.hidden_act,
- quant_config=quant_config,
- reduce_results=False,
- prefix=add_prefix("shared_experts", prefix),
+ params_dtype=config.torch_dtype,
+ async_finish=True, # TODO
  )

- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- num_tokens, hidden_dim = hidden_states.shape
- hidden_states = hidden_states.view(-1, hidden_dim)
+ def forward(
+ self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+ ) -> torch.Tensor:
+ if not global_server_args_dict["enable_deepep_moe"]:
+ return self.forward_normal(hidden_states)
+ else:
+ return self.forward_deepep(hidden_states, forward_mode)
+
+ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
  if self.n_shared_experts is not None:
  shared_output = self.shared_experts(hidden_states)
  # router_logits: (num_tokens, n_experts)
@@ -201,8 +266,60 @@ class DeepseekV2MoE(nn.Module):
  final_hidden_states = final_hidden_states + shared_output
  if self.tp_size > 1:
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+ return final_hidden_states

- return final_hidden_states.view(num_tokens, hidden_dim)
+ def forward_deepep(
+ self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+ ) -> torch.Tensor:
+ shared_output = None
+ topk_idx = torch.full(
+ (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+ )
+ topk_weights = torch.empty(
+ (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+ )
+ if forward_mode is not None and not forward_mode.is_idle():
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(hidden_states)
+ if self.n_shared_experts is not None:
+ shared_output = self.shared_experts(hidden_states)
+ topk_weights, topk_idx = select_experts(
+ hidden_states=hidden_states,
+ router_logits=router_logits,
+ top_k=self.top_k,
+ use_grouped_topk=True,
+ renormalize=self.renormalize,
+ topk_group=self.topk_group,
+ num_expert_group=self.num_expert_group,
+ correction_bias=self.correction_bias,
+ )
+ if self.tp_size > 1:
+ recv_hidden_states, reorder_topk_ids, seg_indptr = (
+ self.deepep_dispatcher.dispatch(
+ hidden_states,
+ topk_idx,
+ topk_weights,
+ self.num_experts,
+ forward_mode,
+ )
+ )
+ final_hidden_states = (
+ self.experts(
+ hidden_states=recv_hidden_states,
+ reorder_topk_ids=reorder_topk_ids,
+ seg_indptr=seg_indptr,
+ forward_mode=forward_mode,
+ )
+ * self.routed_scaling_factor
+ )
+ if self.tp_size > 1:
+ final_hidden_states = self.deepep_dispatcher.combine(
+ final_hidden_states, forward_mode
+ )
+ if shared_output is not None:
+ final_hidden_states = final_hidden_states + shared_output
+
+ return final_hidden_states


  def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
@@ -230,6 +347,7 @@ class DeepseekV2Attention(nn.Module):
  max_position_embeddings: int = 8192,
  quant_config: Optional[QuantizationConfig] = None,
  layer_id=None,
+ reduce_results: bool = True,
  prefix: str = "",
  ) -> None:
  super().__init__()
@@ -241,10 +359,14 @@ class DeepseekV2Attention(nn.Module):
  self.v_head_dim = v_head_dim
  self.q_lora_rank = q_lora_rank
  self.kv_lora_rank = kv_lora_rank
+
+ self.dp_size = get_attention_dp_size()
+ attn_tp_rank = get_attention_tp_rank()
+ attn_tp_size = get_attention_tp_size()
+
  self.num_heads = num_heads
- tp_size = get_tensor_model_parallel_world_size()
- assert num_heads % tp_size == 0
- self.num_local_heads = num_heads // tp_size
+ assert num_heads % attn_tp_size == 0
+ self.num_local_heads = num_heads // attn_tp_size
  self.scaling = self.qk_head_dim**-0.5
  self.rope_theta = rope_theta
  self.max_position_embeddings = max_position_embeddings
@@ -272,6 +394,8 @@ class DeepseekV2Attention(nn.Module):
  bias=False,
  quant_config=quant_config,
  prefix=add_prefix("q_proj", prefix),
+ tp_rank=attn_tp_rank,
+ tp_size=attn_tp_size,
  )

  self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -296,6 +420,9 @@ class DeepseekV2Attention(nn.Module):
  bias=False,
  quant_config=quant_config,
  prefix=add_prefix("o_proj", prefix),
+ reduce_results=reduce_results,
+ tp_rank=attn_tp_rank,
+ tp_size=attn_tp_size,
  )
  rope_scaling["rope_type"] = "deepseek_yarn"
  self.rotary_emb = get_rope_wrapper(
@@ -330,6 +457,12 @@ class DeepseekV2Attention(nn.Module):
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
  ) -> torch.Tensor:
+ if hidden_states.shape[0] == 0:
+ assert (
+ not self.o_proj.reduce_results
+ ), "short-circuiting allreduce will lead to hangs"
+ return hidden_states
+
  if self.q_lora_rank is not None:
  q = self.q_a_proj(hidden_states)[0]
  q = self.q_a_layernorm(q)
@@ -385,8 +518,8 @@ class DeepseekV2AttentionMLA(nn.Module):
  rope_scaling: Optional[Dict[str, Any]] = None,
  max_position_embeddings: int = 8192,
  quant_config: Optional[QuantizationConfig] = None,
- layer_id=None,
- use_dp=False,
+ reduce_results: bool = True,
+ layer_id: int = None,
  prefix: str = "",
  ) -> None:
  super().__init__()
@@ -398,96 +531,66 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.v_head_dim = v_head_dim
  self.q_lora_rank = q_lora_rank
  self.kv_lora_rank = kv_lora_rank
+ self.dp_size = get_attention_dp_size()
+ attn_tp_rank = get_attention_tp_rank()
+ attn_tp_size = get_attention_tp_size()
+
  self.num_heads = num_heads
- tp_size = get_tensor_model_parallel_world_size()
- assert num_heads % tp_size == 0
- self.num_local_heads = num_heads if use_dp else num_heads // tp_size
+ assert num_heads % attn_tp_size == 0
+ self.num_local_heads = num_heads // attn_tp_size
  self.scaling = self.qk_head_dim**-0.5
  self.rope_theta = rope_theta
  self.max_position_embeddings = max_position_embeddings

- if use_dp:
- # For data parallel attention
- if self.q_lora_rank is not None:
- self.q_a_proj = ReplicatedLinear(
- self.hidden_size,
- self.q_lora_rank,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("q_a_proj", prefix),
- )
- self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
- self.q_b_proj = ReplicatedLinear(
- q_lora_rank,
- self.num_heads * self.qk_head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("q_b_proj", prefix),
- )
- else:
- self.q_proj = ReplicatedLinear(
- self.hidden_size,
- self.num_heads * self.qk_head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("q_proj", prefix),
- )
- self.kv_b_proj = ReplicatedLinear(
- self.kv_lora_rank,
- self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("kv_b_proj", prefix),
- )
- # O projection.
- self.o_proj = ReplicatedLinear(
- self.num_heads * self.v_head_dim,
+ # For tensor parallel attention
+ if self.q_lora_rank is not None:
+ self.q_a_proj = ReplicatedLinear(
  self.hidden_size,
+ self.q_lora_rank,
  bias=False,
  quant_config=quant_config,
- prefix=add_prefix("o_proj", prefix),
+ prefix=add_prefix("q_a_proj", prefix),
  )
- else:
- # For tensor parallel attention
- if self.q_lora_rank is not None:
- self.q_a_proj = ReplicatedLinear(
- self.hidden_size,
- self.q_lora_rank,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("q_a_proj", prefix),
- )
- self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
- self.q_b_proj = ColumnParallelLinear(
- q_lora_rank,
- self.num_heads * self.qk_head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("q_b_proj", prefix),
- )
- else:
- self.q_proj = ColumnParallelLinear(
- self.hidden_size,
- self.num_heads * self.qk_head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("q_proj", prefix),
- )
- self.kv_b_proj = ColumnParallelLinear(
- self.kv_lora_rank,
- self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+ self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+ self.q_b_proj = ColumnParallelLinear(
+ q_lora_rank,
+ self.num_heads * self.qk_head_dim,
  bias=False,
  quant_config=quant_config,
- prefix=add_prefix("kv_b_proj", prefix),
+ prefix=add_prefix("q_b_proj", prefix),
+ tp_rank=attn_tp_rank,
+ tp_size=attn_tp_size,
  )
- # O projection.
- self.o_proj = RowParallelLinear(
- self.num_heads * self.v_head_dim,
+ else:
+ self.q_proj = ColumnParallelLinear(
  self.hidden_size,
+ self.num_heads * self.qk_head_dim,
  bias=False,
  quant_config=quant_config,
- prefix=add_prefix("o_proj", prefix),
+ prefix=add_prefix("q_proj", prefix),
+ tp_rank=attn_tp_rank,
+ tp_size=attn_tp_size,
  )
+ self.kv_b_proj = ColumnParallelLinear(
+ self.kv_lora_rank,
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+ bias=False,
+ quant_config=quant_config,
+ prefix=add_prefix("kv_b_proj", prefix),
+ tp_rank=attn_tp_rank,
+ tp_size=attn_tp_size,
+ )
+ # O projection.
+ self.o_proj = RowParallelLinear(
+ self.num_heads * self.v_head_dim,
+ self.hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ reduce_results=reduce_results,
+ prefix=add_prefix("o_proj", prefix),
+ tp_rank=attn_tp_rank,
+ tp_size=attn_tp_size,
+ )

  self.kv_a_proj_with_mqa = ReplicatedLinear(
  self.hidden_size,
@@ -542,38 +645,49 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.w_vc = None
  self.w_scale = None

+ self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
+ self.flashinfer_mla_disable_ragged = global_server_args_dict[
+ "flashinfer_mla_disable_ragged"
+ ]
+ self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+
+ def no_absorb(self, forward_batch: ForwardBatch) -> bool:
+ if self.enable_flashinfer_mla:
+ # Flashinfer MLA: Do not absorb when enabling ragged prefill
+ return (
+ not self.flashinfer_mla_disable_ragged
+ and forward_batch.forward_mode.is_extend()
+ and not forward_batch.forward_mode.is_target_verify()
+ and not forward_batch.forward_mode.is_draft_extend()
+ and sum(forward_batch.extend_prefix_lens_cpu) == 0
+ )
+ else:
+ # Triton: Use normal computation for prefill and use weight absorption for extend/decode
+ return (
+ forward_batch.forward_mode.is_extend()
+ and not forward_batch.forward_mode.is_target_verify()
+ and not forward_batch.forward_mode.is_draft_extend()
+ and sum(forward_batch.extend_prefix_lens_cpu) == 0
+ )
+
  def forward(
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
  ) -> torch.Tensor:
+ if hidden_states.shape[0] == 0:
+ assert (
+ not self.o_proj.reduce_results
+ ), "short-circuiting allreduce will lead to hangs"
+ return hidden_states

- def no_absorb() -> bool:
- if global_server_args_dict["enable_flashinfer_mla"]:
- # Flashinfer MLA: Do not absorb when enabling ragged prefill
- return (
- not global_server_args_dict["flashinfer_mla_disable_ragged"]
- and forward_batch.forward_mode.is_extend()
- and not forward_batch.forward_mode.is_target_verify()
- and not forward_batch.forward_mode.is_draft_extend()
- and forward_batch.extend_prefix_lens.sum() == 0
- )
- else:
- # Triton: Use normal computation for prefill and use weight absorption for extend/decode
- return (
- forward_batch.forward_mode.is_extend()
- and not forward_batch.forward_mode.is_target_verify()
- and not forward_batch.forward_mode.is_draft_extend()
- and forward_batch.extend_prefix_lens.sum() == 0
- )
-
- if no_absorb():
+ if self.no_absorb(forward_batch):
  return self.forward_normal(positions, hidden_states, forward_batch)
  else:
  if _is_hip:
  if (
- os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+ self.rocm_fused_decode_mla
  and forward_batch.forward_mode.is_decode()
  ):
  return self.forward_absorb_fused_mla_rope(
@@ -845,34 +959,6 @@ class DeepseekV2AttentionMLA(nn.Module):
  return output


- def all_gather(
- input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
- ):
- all_lens = forward_batch.global_num_tokens_cpu
- max_len = max(forward_batch.global_num_tokens_cpu)
-
- if world_size == 1:
- return input_tensor, 0, all_lens[0]
-
- padded_tensor = torch.nn.functional.pad(
- input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
- )
-
- group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)
-
- gathered_tensors = torch.concat(
- [
- forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
- for i in range(world_size)
- ]
- )
-
- start_index = 0 if rank == 0 else sum(all_lens[:rank])
- end_index = start_index + all_lens[rank]
-
- return gathered_tensors, start_index, end_index
-
-
  class DeepseekV2DecoderLayer(nn.Module):

  def __init__(
@@ -888,14 +974,10 @@ class DeepseekV2DecoderLayer(nn.Module):
  rope_theta = getattr(config, "rope_theta", 10000)
  rope_scaling = getattr(config, "rope_scaling", None)
  max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
- self.enable_dp_attention = (
- not global_server_args_dict["disable_mla"]
- and global_server_args_dict["enable_dp_attention"]
- )
- if self.enable_dp_attention:
- self.tp_rank = get_tensor_model_parallel_rank()
- self.tp_size = get_tensor_model_parallel_world_size()
- self.tp_group = get_tp_group()
+ self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+ self.layer_id = layer_id
+ self.dp_size = get_attention_dp_size()
+
  if not global_server_args_dict["disable_mla"]:
  self.self_attn = DeepseekV2AttentionMLA(
  config=config,
@@ -913,7 +995,7 @@ class DeepseekV2DecoderLayer(nn.Module):
  max_position_embeddings=max_position_embeddings,
  quant_config=quant_config,
  layer_id=layer_id,
- use_dp=self.enable_dp_attention,
+ reduce_results=False,
  prefix=add_prefix("self_attn", prefix),
  )
  else:
@@ -933,8 +1015,10 @@ class DeepseekV2DecoderLayer(nn.Module):
  max_position_embeddings=max_position_embeddings,
  quant_config=quant_config,
  layer_id=layer_id,
+ reduce_results=False,
  prefix=add_prefix("self_attn", prefix),
  )
+
  if is_nextn or (
  config.n_routed_experts is not None
  and layer_id >= config.first_k_dense_replace
@@ -965,32 +1049,67 @@ class DeepseekV2DecoderLayer(nn.Module):
  forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  ) -> torch.Tensor:
- # Self Attention
- if not forward_batch.forward_mode.is_idle():
+ if hidden_states.shape[0] == 0:
+ residual = hidden_states
+ else:
  if residual is None:
  residual = hidden_states
  hidden_states = self.input_layernorm(hidden_states)
  else:
  hidden_states, residual = self.input_layernorm(hidden_states, residual)

+ # Self Attention
  hidden_states = self.self_attn(
  positions=positions,
  hidden_states=hidden_states,
  forward_batch=forward_batch,
  )
+
+ # Gather
+ if get_tensor_model_parallel_world_size() > 1:
+ # all gather and all reduce
+ if self.dp_size != 1:
+ if global_server_args_dict["enable_deepep_moe"] and isinstance(
+ self.mlp, DeepseekV2MoE
+ ):
+ if hidden_states.shape[0] != 0:
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual
+ )
+ hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+ return hidden_states, residual
+ else:
+ if get_attention_tp_rank() == 0:
+ hidden_states += residual
+ hidden_states, local_hidden_states = (
+ forward_batch.gathered_buffer,
+ hidden_states,
+ )
+ dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+ dp_scatter(residual, hidden_states, forward_batch)
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ else:
+ hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual
+ )
+ else:
  hidden_states, residual = self.post_attention_layernorm(
  hidden_states, residual
  )

  # Fully Connected
- if self.enable_dp_attention:
- hidden_states, start_idx, end_idx = all_gather(
- hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
+ hidden_states = self.mlp(hidden_states)
+
+ # Scatter
+ if self.dp_size != 1:
+ # important: forward batch.gathered_buffer is used both after scatter and after gather.
+ # be careful about this!
+ hidden_states, global_hidden_states = (
+ forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+ hidden_states,
  )
- hidden_states = self.mlp(hidden_states)
- hidden_states = hidden_states[start_idx:end_idx]
- else:
- hidden_states = self.mlp(hidden_states)
+ dp_scatter(hidden_states, global_hidden_states, forward_batch)

  return hidden_states, residual

@@ -1027,15 +1146,24 @@ class DeepseekV2Model(nn.Module):
  )
  self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+ self.dp_size = get_attention_dp_size()
+
  def forward(
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
  forward_batch: ForwardBatch,
+ input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
- hidden_states = self.embed_tokens(input_ids)
+
+ if input_embeds is None:
+ hidden_states = self.embed_tokens(input_ids)
+ else:
+ hidden_states = input_embeds
+
  residual = None
  for i in range(len(self.layers)):
+ expert_distribution_recorder.set_current_layer(i)
  layer = self.layers[i]
  hidden_states, residual = layer(
  positions, hidden_states, forward_batch, residual
@@ -1059,22 +1187,14 @@ class DeepseekV2ForCausalLM(nn.Module):
  self.model = DeepseekV2Model(
  config, quant_config, prefix=add_prefix("model", prefix)
  )
- if global_server_args_dict["enable_dp_attention"]:
- self.lm_head = ReplicatedLinear(
- config.hidden_size,
- config.vocab_size,
- bias=False,
- prefix=add_prefix("lm_head", prefix),
- )
- self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
- else:
- self.lm_head = ParallelLMHead(
- config.vocab_size,
- config.hidden_size,
- quant_config=quant_config,
- prefix=add_prefix("lm_head", prefix),
- )
- self.logits_processor = LogitsProcessor(config)
+ self.lm_head = ParallelLMHead(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=add_prefix("lm_head", prefix),
+ )
+ self.logits_processor = LogitsProcessor(config)
+ self.dp_size = get_attention_dp_size()

  @torch.no_grad()
  def forward(
@@ -1082,8 +1202,11 @@ class DeepseekV2ForCausalLM(nn.Module):
  input_ids: torch.Tensor,
  positions: torch.Tensor,
  forward_batch: ForwardBatch,
+ input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, forward_batch)
+
+ hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+
  return self.logits_processor(
  input_ids, hidden_states, self.lm_head, forward_batch
  )
@@ -1097,7 +1220,11 @@ class DeepseekV2ForCausalLM(nn.Module):

  # Params for weights, fp8 weight scales, fp8 activation scales
  # (param_name, weight_name, expert_id, shard_id)
- MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+ MoEImpl = (
+ DeepEPMoE
+ if global_server_args_dict["enable_deepep_moe"]
+ else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+ )
  expert_params_mapping = MoEImpl.make_expert_params_mapping(
  ckpt_gate_proj_name="gate_proj",
  ckpt_down_proj_name="down_proj",
@@ -1171,14 +1298,21 @@ class DeepseekV2ForCausalLM(nn.Module):
  self_attn = self.model.layers[layer_id].self_attn
  if hasattr(self_attn.kv_b_proj, "qweight"):
  # AWQ compatible
- w = ops.awq_dequantize(
- self_attn.kv_b_proj.qweight,
- self_attn.kv_b_proj.scales,
- self_attn.kv_b_proj.qzeros,
- 0,
- 0,
- 0,
- ).T
+ if _is_cuda:
+ w = awq_dequantize(
+ self_attn.kv_b_proj.qweight,
+ self_attn.kv_b_proj.scales,
+ self_attn.kv_b_proj.qzeros,
+ ).T
+ else:
+ w = ops.awq_dequantize(
+ self_attn.kv_b_proj.qweight,
+ self_attn.kv_b_proj.scales,
+ self_attn.kv_b_proj.qzeros,
+ 0,
+ 0,
+ 0,
+ ).T
  else:
  w = self_attn.kv_b_proj.weight
  # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.