sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
@@ -72,7 +72,7 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -93,10 +93,13 @@ from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
     LazyValue,
+    PackWeightMethod,
     add_prefix,
     bind_or_assign,
+    cpu_has_amx_support,
     get_bool_env_var,
     get_int_env_var,
+    is_cpu,
     is_cuda,
     is_hip,
     is_non_idle_and_non_empty,
@@ -107,9 +110,13 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
+elif _is_cpu and _is_cpu_amx_available:
+    pass
 else:
     from vllm._custom_ops import awq_dequantize
 
@@ -118,8 +125,6 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )
 
-    if _use_aiter:
-        from aiter.rotary_embedding import get_rope
 
 logger = logging.getLogger(__name__)
 
@@ -138,6 +143,9 @@ class AttnForwardMethod(IntEnum):
     # Use MLA but with fused RoPE
     MLA_FUSED_ROPE = auto()
 
+    # Use MLA with fused RoPE kernel for CPU
+    MLA_FUSED_ROPE_CPU = auto()
+
 
 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -206,8 +214,18 @@ class MoEGate(nn.Module):
             )
         else:
             self.e_score_correction_bias = None
+        if _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])
 
     def forward(self, hidden_states):
+        if getattr(self, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                hidden_states,
+                self.weight,
+                None,  # bias
+                True,  # is_vnni
+            )
+
         logits = F.linear(hidden_states, self.weight, None)
         return logits
 
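
Note (not part of the diff): the new MoEGate.forward is a feature-gated dispatch. PackWeightMethod marks the module with use_intel_amx_backend when CPU AMX weight packing applies, and forward then routes to the packed sgl_kernel op, falling back to F.linear otherwise. A minimal runnable sketch of that gating pattern follows; it substitutes F.linear for the packed kernel so it runs without sgl_kernel, and the class name is illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F


class GateSketch(nn.Module):
    # Illustrative stand-in for MoEGate: an unbiased linear producing router logits.
    def __init__(self, hidden_size: int, n_experts: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(n_experts, hidden_size) * 0.02)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The flag only exists if a pack/quant method attached it, hence getattr with a default.
        if getattr(self, "use_intel_amx_backend", False):
            # sglang calls torch.ops.sgl_kernel.weight_packed_linear(
            #     hidden_states, self.weight, None, True) here; F.linear is a stand-in.
            return F.linear(hidden_states, self.weight, None)
        return F.linear(hidden_states, self.weight, None)


gate = GateSketch(hidden_size=16, n_experts=8)
print(gate(torch.randn(4, 16)).shape)  # torch.Size([4, 8]) == (num_tokens, n_experts)
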
@@ -220,6 +238,7 @@ class DeepseekV2MoE(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -232,6 +251,7 @@ class DeepseekV2MoE(nn.Module):
         )
         self.config = config
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -269,6 +289,15 @@ class DeepseekV2MoE(nn.Module):
                 if global_server_args_dict["enable_deepep_moe"]
                 else {}
             ),
+            # Additional args for FusedMoE
+            **(
+                dict(
+                    enable_flashinfer_moe=True,
+                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+                )
+                if global_server_args_dict["enable_flashinfer_moe"]
+                else {}
+            ),
         )
 
         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
@@ -332,10 +361,38 @@ class DeepseekV2MoE(nn.Module):
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
-            return self.forward_normal(hidden_states)
+            DUAL_STREAM_TOKEN_THRESHOLD = 1024
+            if (
+                self.alt_stream is not None
+                and self.num_fused_shared_experts == 0
+                and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            ):
+                return self.forward_normal_dual_stream(hidden_states)
+            else:
+                return self.forward_normal(hidden_states)
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
+    def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states)
+
+        with torch.cuda.stream(self.alt_stream):
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
+            if not _is_cuda:
+                final_hidden_states *= self.routed_scaling_factor
+        current_stream.wait_stream(self.alt_stream)
+        final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
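
Note (not part of the diff): forward_normal_dual_stream overlaps the shared-expert MLP on the current stream with the routed-expert computation on alt_stream; the two wait_stream calls order the gate output before the fork and the routed output before the join. A self-contained sketch of that pattern, with plain matmuls standing in for the expert computations:

import torch


def overlapped_forward(x: torch.Tensor, w_shared: torch.Tensor, w_routed: torch.Tensor,
                       alt_stream: torch.cuda.Stream) -> torch.Tensor:
    current_stream = torch.cuda.current_stream()
    # alt_stream must see x (and everything that produced it) before reading it.
    alt_stream.wait_stream(current_stream)
    shared_out = x @ w_shared                  # runs on the current stream
    with torch.cuda.stream(alt_stream):
        routed_out = x @ w_routed              # runs concurrently on alt_stream
    # The current stream must not consume routed_out before alt_stream has produced it.
    current_stream.wait_stream(alt_stream)
    return shared_out + routed_out


if torch.cuda.is_available():
    x = torch.randn(256, 1024, device="cuda")
    w_s = torch.randn(1024, 1024, device="cuda")
    w_r = torch.randn(1024, 1024, device="cuda")
    out = overlapped_forward(x, w_s, w_r, torch.cuda.Stream())
    torch.cuda.synchronize()
    print(out.shape)  # torch.Size([256, 1024])
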
@@ -343,7 +400,8 @@ class DeepseekV2MoE(nn.Module):
         final_hidden_states = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
-        if not _is_cuda:
+        if not _is_cuda and not _use_aiter:
+            # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
@@ -665,13 +723,14 @@ class DeepseekV2AttentionMLA(nn.Module):
         if rope_scaling:
             rope_scaling["rope_type"] = "deepseek_yarn"
 
-        self.rotary_emb = get_rope(
+        self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
+            device=global_server_args_dict["device"],
         )
 
         if rope_scaling:
@@ -731,6 +790,37 @@ class DeepseekV2AttentionMLA(nn.Module):
             "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
         )
 
+        # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we will choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel
+        # which requires self.w_kc and self.w_vc to be packed.
+        # If not, we will use torch.bmm and weight shouldn't be packed in this case
+        if (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and _is_cpu
+            and _is_cpu_amx_available
+        ):
+            self.quant_method = PackWeightMethod(
+                weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
+            )
+
+        self.qkv_proj_with_rope_is_int8 = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
+        )
+        self.qkv_proj_with_rope_is_fp8 = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
+        )
+
+        self.weight_block_size = None
+        if self.qkv_proj_with_rope_is_fp8:
+            assert (
+                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                == self.q_b_proj.quant_method.quant_config.weight_block_size
+            )
+            self.weight_block_size = (
+                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+            )
+
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
@@ -744,7 +834,12 @@ class DeepseekV2AttentionMLA(nn.Module):
                 else:
                     return AttnForwardMethod.MLA
             else:
-                return AttnForwardMethod.MLA
+                if hasattr(self, "fused_qkv_a_proj_with_mqa") and getattr(
+                    self, "use_intel_amx_backend", False
+                ):
+                    return AttnForwardMethod.MLA_FUSED_ROPE_CPU
+                else:
+                    return AttnForwardMethod.MLA
 
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
@@ -858,6 +953,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             inner_state = self.forward_absorb_fused_mla_rope_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
            )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+            inner_state = self.forward_absorb_fused_mla_rope_cpu_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         else:
             raise NotImplementedError
         return None, attn_forward_method, forward_batch, inner_state
@@ -877,6 +976,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_absorb_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
             return self.forward_absorb_fused_mla_rope_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+            return self.forward_absorb_fused_mla_rope_cpu_core(*inner_state)
         else:
             raise NotImplementedError
 
@@ -1040,13 +1141,16 @@ class DeepseekV2AttentionMLA(nn.Module):
                 masked_m,
                 expected_m,
             )
-            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
+            attn_bmm_output = (
+                attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2)
+            )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
+            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
@@ -1059,10 +1163,21 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_scale,
                 torch.bfloat16,
             )
+            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         else:
-            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
-        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
-        output, _ = self.o_proj(attn_output)
+            attn_bmm_output = torch.empty(
+                (attn_output.shape[0], self.num_local_heads * self.v_head_dim),
+                dtype=attn_output.dtype,
+                device=attn_output.device,
+            )
+            torch.bmm(
+                attn_output.transpose(0, 1),
+                self.w_vc,
+                out=attn_bmm_output.view(
+                    -1, self.num_local_heads, self.v_head_dim
+                ).transpose(0, 1),
+            )
+        output, _ = self.o_proj(attn_bmm_output)
 
         return output
 
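
Note (not part of the diff): the rewritten default branch above preallocates the [num_tokens, num_local_heads * v_head_dim] output and lets torch.bmm write into a transposed view of it, so the old transpose(0, 1).flatten(1, 2) copy before o_proj disappears. The equivalence can be checked with dummy shapes (sizes below are illustrative):

import torch

M, H, K, D = 7, 16, 512, 128          # tokens, local heads, kv_lora_rank, v_head_dim
attn_output = torch.randn(M, H, K)
w_vc = torch.randn(H, K, D)

# Old path: bmm into a fresh [H, M, D] tensor, then transpose + flatten (extra copy).
ref = torch.bmm(attn_output.transpose(0, 1), w_vc).transpose(0, 1).reshape(M, H * D)

# New path: bmm writes straight into a transposed view of the [M, H * D] buffer.
out = torch.empty(M, H * D)
torch.bmm(attn_output.transpose(0, 1), w_vc,
          out=out.view(M, H, D).transpose(0, 1))

print(torch.allclose(ref, out, atol=1e-4))  # True
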
@@ -1180,6 +1295,57 @@ class DeepseekV2AttentionMLA(nn.Module):
             zero_allocator,
         )
 
+    def forward_absorb_fused_mla_rope_cpu_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        assert self.q_lora_rank is not None and getattr(
+            self, "use_intel_amx_backend", False
+        ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"
+
+        q_input, k_input, v_input = (
+            torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight(
+                hidden_states,
+                self.fused_qkv_a_proj_with_mqa.weight,
+                self.q_b_proj.weight,
+                self.w_kc,
+                self.q_a_layernorm.weight,
+                self.kv_a_layernorm.weight,
+                positions,
+                self.rotary_emb.cos_sin_cache,
+                self.kv_a_layernorm.variance_epsilon,
+                self.qkv_proj_with_rope_is_int8,
+                self.qkv_proj_with_rope_is_fp8,
+                (
+                    self.fused_qkv_a_proj_with_mqa.weight_scale
+                    if self.qkv_proj_with_rope_is_int8
+                    else (
+                        self.fused_qkv_a_proj_with_mqa.weight_scale_inv
+                        if self.qkv_proj_with_rope_is_fp8
+                        else None
+                    )
+                ),
+                (
+                    self.q_b_proj.weight_scale
+                    if self.qkv_proj_with_rope_is_int8
+                    else (
+                        self.q_b_proj.weight_scale_inv
+                        if self.qkv_proj_with_rope_is_fp8
+                        else None
+                    )
+                ),
+                True,  # is_vnni
+                self.weight_block_size,
+                self.q_lora_rank,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+            )
+        )
+        return (q_input, k_input, v_input, forward_batch, zero_allocator)
+
     def forward_absorb_fused_mla_rope_core(
         self,
         q_input,
@@ -1253,6 +1419,43 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return output
 
+    def forward_absorb_fused_mla_rope_cpu_core(
+        self, q_input, k_input, v_input, forward_batch, zero_allocator
+    ):
+        assert self.q_lora_rank is not None and getattr(
+            self, "use_intel_amx_backend", False
+        ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"
+
+        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        # [Note] Align shapes of bmm inputs.
+        # Shapes of inputs:
+        # q_nope: [M, B, K]
+        # original self.w_kc: [B, K, N]
+        # current self.w_kc (which has been converted in PackWeightMethod): [B, N, K]
+
+        # Shapes of inputs to sgl_kernel.cpu.bmm:
+        # out: [B, M, N]
+        # mat1: [B, M, K]
+        # mat2: [B, N, K]
+        B = self.w_vc.size(0)
+        N = self.w_vc.size(1)
+        M = attn_output.size(0)
+        output = torch.empty([M, int(B * N)], dtype=attn_output.dtype)
+        attn_bmm_output = output.view([M, B, N]).transpose_(0, 1)
+        torch.ops.sgl_kernel.bmm_cpu(
+            attn_bmm_output,
+            attn_output.transpose(0, 1),
+            self.w_vc,
+            True,  # is_vnni
+            None,  # scale
+        )
+        attn_output = output
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
     def _chunked_prefix_attn_mha(
         self,
         q: torch.Tensor,
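
Note (not part of the diff): the shape comment in forward_absorb_fused_mla_rope_cpu_core says PackWeightMethod stores w_kc/w_vc pre-transposed as [B, N, K] because the sgl_kernel CPU bmm expects mat2 in that layout (plus VNNI packing). The layout relationship itself can be checked with plain torch.bmm standing in for torch.ops.sgl_kernel.bmm_cpu:

import torch

B, M, K, N = 16, 7, 512, 128            # heads, tokens, kv_lora_rank, v_head_dim
attn_output = torch.randn(M, B, K)       # [M, B, K], as in the note above
w_vc_orig = torch.randn(B, K, N)         # original weight layout [B, K, N]
w_vc_packed = w_vc_orig.transpose(1, 2).contiguous()  # stored layout [B, N, K]

# Reference: standard bmm against the original [B, K, N] weight.
ref = torch.bmm(attn_output.transpose(0, 1), w_vc_orig)          # [B, M, N]

# Packed path: out is a [B, M, N] view of a flat [M, B * N] buffer, and the
# [B, N, K] weight is un-transposed before the matmul (the CPU kernel does
# this implicitly on packed data).
output = torch.empty(M, B * N)
attn_bmm_output = output.view(M, B, N).transpose(0, 1)            # [B, M, N] view
torch.bmm(attn_output.transpose(0, 1), w_vc_packed.transpose(1, 2), out=attn_bmm_output)

print(torch.allclose(ref, attn_bmm_output, atol=1e-4))  # True
print(output.shape)  # torch.Size([7, 2048]) -> flat [M, B * N] rows fed to o_proj
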
@@ -1399,7 +1602,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
         self.layer_id = layer_id
+        self.is_nextn = is_nextn
         self.self_attn = DeepseekV2AttentionMLA(
             config=config,
             hidden_size=self.hidden_size,
@@ -1426,7 +1631,7 @@ class DeepseekV2DecoderLayer(nn.Module):
 
         self.layer_scatter_modes = LayerScatterModes.init_new(
             layer_id=layer_id,
-            num_layers=config.num_hidden_layers,
+            num_layers=1 if is_nextn else config.num_hidden_layers,
             is_layer_sparse=self.is_layer_sparse,
             is_previous_layer_sparse=is_previous_layer_sparse,
         )
@@ -1437,6 +1642,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
+                alt_stream=alt_stream,
             )
         else:
             if enable_moe_dense_fully_dp():
@@ -1479,6 +1685,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
+
         hidden_states, residual = self.layer_communicator.prepare_attn(
             hidden_states, residual, forward_batch
         )
@@ -1500,6 +1707,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
+        if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
+            # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
+            # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
+            hidden_states = hidden_states.clone()
+
         return hidden_states, residual
 
     def op_comm_prepare_attn(
@@ -1607,8 +1819,6 @@ class DeepseekV2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        self.dp_size = get_local_attention_dp_size()
-
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
 
@@ -1692,7 +1902,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_local_attention_dp_size()
 
         self._routed_experts_weights_of_layer = LazyValue(
             lambda: {
@@ -1717,12 +1926,12 @@ class DeepseekV2ForCausalLM(nn.Module):
         disable_reason = None
         if (
             not _is_cuda
-            or torch.cuda.get_device_capability("cuda") < (9, 0)
+            or torch.cuda.get_device_capability("cuda") < (8, 0)
             or self.config.architectures[0] != architecture
             or self.config.n_routed_experts != 256
             or self.config.n_shared_experts != 1
         ):
-            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 90 can use shared experts fusion optimization."
+            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif (
             global_server_args_dict["enable_deepep_moe"]
             or global_server_args_dict["enable_ep_moe"]
@@ -1919,10 +2128,12 @@ class DeepseekV2ForCausalLM(nn.Module):
         if (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+            and hasattr(self.quant_config, "weight_block_size")
+            and self.quant_config.weight_block_size is not None
         ):
-            self._weight_requant_ue8m0()
+            self._weight_requant_ue8m0(is_nextn)
 
-    def _weight_requant_ue8m0(self):
+    def _weight_requant_ue8m0(self, is_nextn=False):
         weight_block_size = self.quant_config.weight_block_size
 
         moe_layers = list(
@@ -1933,8 +2144,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             )
         )
 
-        for layer_id in range(self.config.num_hidden_layers):
-            layer = self.model.layers[layer_id]
+        num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+        for layer_id in range(num_hidden_layers):
+            if is_nextn:
+                layer = self.model.decoder
+            else:
+                layer = self.model.layers[layer_id]
 
             for module in [
                 layer.self_attn.fused_qkv_a_proj_with_mqa,
@@ -1946,7 +2161,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
 
-            if layer_id in moe_layers:
+            if layer_id in moe_layers or is_nextn:
                 shared_experts = getattr(layer.mlp, "shared_experts", None)
                 if shared_experts is not None:
                     for module in [
@@ -2022,7 +2237,7 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         if self.num_fused_shared_experts > 0:
             assert self.num_fused_shared_experts == 1
-            logger.info("Shared experts fusion optimization enabled.")
+            log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
 
         params_dict = dict(self.named_parameters())
         weight_names = []
@@ -2128,8 +2343,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                        ):
                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                           cat_dim = 0
+                           if self.quant_config is not None and (
+                               self.quant_config.get_name() == "awq"
+                               or self.quant_config.get_name() == "moe_wna16"
+                           ):
+                               cat_dim = 1
                            fused_weight = torch.cat(
-                               [q_a_proj_weight, kv_a_proj_weight], dim=0
+                               [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                            )
                            param_name = (
                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
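
Note (not part of the diff): cat_dim switches because the fused q_a/kv_a weight is built by stacking the two projections along the output-feature axis, and that axis moves with the checkpoint layout: regular weights are stored [out_features, in_features] (stack on dim 0), while AWQ / moe_wna16 qweights are packed roughly as [in_features, out_features // pack_factor] (stack on dim 1). A shape-only sketch; the pack factor and the exact AWQ layout are assumptions here, not taken from this diff:

import torch

hidden, q_lora_rank, kv_dim = 7168, 1536, 576   # DeepSeek-V3-like sizes
pack_factor = 8                                  # assumed int4-in-int32 packing

# Plain layout: [out_features, in_features] -> stack outputs along dim 0.
q_a = torch.empty(q_lora_rank, hidden)
kv_a = torch.empty(kv_dim, hidden)
fused = torch.cat([q_a, kv_a], dim=0)
print(fused.shape)    # torch.Size([2112, 7168])

# AWQ-style layout: [in_features, out_features // pack_factor] -> outputs live on dim 1.
q_a_q = torch.empty(hidden, q_lora_rank // pack_factor, dtype=torch.int32)
kv_a_q = torch.empty(hidden, kv_dim // pack_factor, dtype=torch.int32)
fused_q = torch.cat([q_a_q, kv_a_q], dim=1)
print(fused_q.shape)  # torch.Size([7168, 264])
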
@@ -2151,12 +2372,16 @@ class DeepseekV2ForCausalLM(nn.Module):
                            "k_scale" in name or "v_scale" in name
                        ) and name not in params_dict:
                            # modelopt attn kv scale is named differently
-                           if any(scale in name for scale in ["k_scale", "v_scale"]):
-                               name = name.replace("_proj", "attn_mqa")
-                           else:
-                               logger.warning(
-                                   f"Unknown scale found in checkpoint: {name}"
-                               )
+                           for scale in ["k_scale", "v_scale"]:
+                               if scale in name:
+                                   name = name.replace(f"{scale[0]}_proj", "attn_mqa")
+                                   break
+                           if name not in params_dict:
+                               # modelopt ckpt contains not needed weights for MTP module:
+                               # model.decoder.self_attn.attn_mqa.v_scale and
+                               # model.decoder.self_attn.attn_mqa.k_scale
+                               logger.warning(f"{name} not found in params_dict.")
+                               continue
                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader