sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -51,11 +51,11 @@ from sglang.srt.layers.linear import (
  RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
- from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+ from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
  from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
  from sglang.srt.layers.moe.topk import select_experts
+ from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
- from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
  from sglang.srt.layers.quantization.fp8_kernel import (
  is_fp8_fnuz,
  per_tensor_quant_mla_fp8,
@@ -66,12 +66,13 @@ from sglang.srt.layers.quantization.fp8_utils import (
  block_quant_to_tensor_quant,
  channel_quant_to_tensor_quant,
  normalize_e4m3fn_to_e4m3fnuz,
+ requant_weight_ue8m0_inplace,
  )
  from sglang.srt.layers.quantization.int8_utils import (
  block_dequant as int8_block_dequant,
  )
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
  from sglang.srt.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
@@ -94,8 +95,10 @@ from sglang.srt.utils import (
  LazyValue,
  add_prefix,
  bind_or_assign,
+ cpu_has_amx_support,
  get_bool_env_var,
  get_int_env_var,
+ is_cpu,
  is_cuda,
  is_hip,
  is_non_idle_and_non_empty,
@@ -106,13 +109,13 @@ _is_hip = is_hip()
  _is_cuda = is_cuda()
  _is_fp8_fnuz = is_fp8_fnuz()
  _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+ _is_cpu_amx_available = cpu_has_amx_support()
+ _is_cpu = is_cpu()

  if _is_cuda:
  from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
-
- from sglang.srt.layers.quantization.deep_gemm import (
- grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
- )
+ elif _is_cpu and _is_cpu_amx_available:
+ pass
  else:
  from vllm._custom_ops import awq_dequantize

@@ -223,6 +226,7 @@ class DeepseekV2MoE(nn.Module):
  layer_id: int,
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
+ alt_stream: Optional[torch.cuda.Stream] = None,
  ):
  super().__init__()
  self.tp_size = get_tensor_model_parallel_world_size()
@@ -235,6 +239,7 @@ class DeepseekV2MoE(nn.Module):
  )
  self.config = config
  self.layer_id = layer_id
+ self.alt_stream = alt_stream

  if self.tp_size > config.n_routed_experts:
  raise ValueError(
@@ -272,6 +277,15 @@ class DeepseekV2MoE(nn.Module):
  if global_server_args_dict["enable_deepep_moe"]
  else {}
  ),
+ # Additional args for FusedMoE
+ **(
+ dict(
+ enable_flashinfer_moe=True,
+ enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+ )
+ if global_server_args_dict["enable_flashinfer_moe"]
+ else {}
+ ),
  )

  if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
@@ -335,10 +349,38 @@ class DeepseekV2MoE(nn.Module):
  self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
  ) -> torch.Tensor:
  if not self._enable_deepep_moe:
- return self.forward_normal(hidden_states)
+ DUAL_STREAM_TOKEN_THRESHOLD = 1024
+ if (
+ self.alt_stream is not None
+ and self.num_fused_shared_experts == 0
+ and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+ ):
+ return self.forward_normal_dual_stream(hidden_states)
+ else:
+ return self.forward_normal(hidden_states)
  else:
  return self.forward_deepep(hidden_states, forward_batch)

+ def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(hidden_states)
+
+ current_stream = torch.cuda.current_stream()
+ self.alt_stream.wait_stream(current_stream)
+ shared_output = self._forward_shared_experts(hidden_states)
+
+ with torch.cuda.stream(self.alt_stream):
+ final_hidden_states = self.experts(
+ hidden_states=hidden_states, router_logits=router_logits
+ )
+ if not _is_cuda:
+ final_hidden_states *= self.routed_scaling_factor
+ current_stream.wait_stream(self.alt_stream)
+ final_hidden_states = final_hidden_states + shared_output
+ if self.tp_size > 1:
+ final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+ return final_hidden_states
+
  def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
  shared_output = self._forward_shared_experts(hidden_states)
  # router_logits: (num_tokens, n_experts)
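A minimal, self-contained sketch of the dual-stream overlap used in forward_normal_dual_stream above: the shared path is issued on the current CUDA stream while the routed path runs on an alternate stream, and the wait_stream calls express the two ordering constraints (the alternate stream must see the input before using it, and the current stream must not combine results early). Plain matmuls stand in for the shared-expert and routed-expert MLPs; this assumes a CUDA device is available.

import torch

def dual_stream_overlap(x, w_shared, w_routed, alt_stream):
    current = torch.cuda.current_stream()
    alt_stream.wait_stream(current)       # alternate stream must see x fully materialized
    shared = x @ w_shared                 # shared path: issued on the current stream
    with torch.cuda.stream(alt_stream):
        routed = x @ w_routed             # routed path: runs concurrently on the alternate stream
    current.wait_stream(alt_stream)       # join before combining the two results
    return routed + shared

if torch.cuda.is_available():
    x = torch.randn(256, 1024, device="cuda")
    w_shared = torch.randn(1024, 1024, device="cuda")
    w_routed = torch.randn(1024, 1024, device="cuda")
    out = dual_stream_overlap(x, w_shared, w_routed, torch.cuda.Stream())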
@@ -668,13 +710,14 @@ class DeepseekV2AttentionMLA(nn.Module):
  if rope_scaling:
  rope_scaling["rope_type"] = "deepseek_yarn"

- self.rotary_emb = get_rope(
+ self.rotary_emb = get_rope_wrapper(
  qk_rope_head_dim,
  rotary_dim=qk_rope_head_dim,
  max_position=max_position_embeddings,
  base=rope_theta,
  rope_scaling=rope_scaling,
  is_neox_style=False,
+ device=global_server_args_dict["device"],
  )

  if rope_scaling:
@@ -980,7 +1023,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  q_nope_out = q_nope.new_empty(
  (self.num_local_heads, aligned_m, self.kv_lora_rank)
  )
- deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
  (q_nope_val, q_nope_scale),
  (self.w_kc, self.w_scale_k),
  q_nope_out,
@@ -1013,7 +1056,11 @@ class DeepseekV2AttentionMLA(nn.Module):
  def forward_absorb_core(
  self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
  ):
- if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
+ if (
+ self.attention_backend == "fa3"
+ or self.attention_backend == "flashinfer"
+ or self.attention_backend == "cutlass_mla"
+ ):
  attn_output = self.attn_mqa(
  q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
  )
@@ -1032,20 +1079,23 @@ class DeepseekV2AttentionMLA(nn.Module):
  attn_bmm_output = attn_output.new_empty(
  (self.num_local_heads, aligned_m, self.v_head_dim)
  )
- deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
  (attn_output_val, attn_output_scale),
  (self.w_vc, self.w_scale_v),
  attn_bmm_output,
  masked_m,
  expected_m,
  )
- attn_bmm_output = attn_bmm_output[:, :expected_m, :]
+ attn_bmm_output = (
+ attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2)
+ )
  elif _is_hip:
  # TODO(haishaw): add bmm_fp8 to ROCm
  attn_bmm_output = torch.bmm(
  attn_output.to(torch.bfloat16).transpose(0, 1),
  self.w_vc.to(torch.bfloat16) * self.w_scale,
  )
+ attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
  elif self.w_vc.dtype == torch.float8_e4m3fn:
  attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
  attn_output.transpose(0, 1),
@@ -1058,10 +1108,21 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.w_scale,
  torch.bfloat16,
  )
+ attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
  else:
- attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
- attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
- output, _ = self.o_proj(attn_output)
+ attn_bmm_output = torch.empty(
+ (attn_output.shape[0], self.num_local_heads * self.v_head_dim),
+ dtype=attn_output.dtype,
+ device=attn_output.device,
+ )
+ torch.bmm(
+ attn_output.transpose(0, 1),
+ self.w_vc,
+ out=attn_bmm_output.view(
+ -1, self.num_local_heads, self.v_head_dim
+ ).transpose(0, 1),
+ )
+ output, _ = self.o_proj(attn_bmm_output)

  return output

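In the unquantized branch above, the bmm result is now written straight into a transposed view of a pre-allocated (tokens, num_local_heads * v_head_dim) buffer, so the old transpose-and-flatten copy before o_proj disappears. A small sketch with made-up sizes, checking that the view-based write matches the copy-based version:

import torch

heads, tokens, kv_rank, v_dim = 4, 8, 16, 32           # illustrative sizes only
attn_output = torch.randn(tokens, heads, kv_rank)       # (tokens, heads, kv_lora_rank)
w_vc = torch.randn(heads, kv_rank, v_dim)               # per-head value projection

out_flat = torch.empty(tokens, heads * v_dim)
torch.bmm(
    attn_output.transpose(0, 1),                         # (heads, tokens, kv_rank)
    w_vc,                                                # (heads, kv_rank, v_dim)
    out=out_flat.view(tokens, heads, v_dim).transpose(0, 1),
)
# Same values as the copy-based version the old code used:
ref = torch.bmm(attn_output.transpose(0, 1), w_vc).transpose(0, 1).flatten(1, 2)
assert torch.allclose(out_flat, ref, atol=1e-5)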
@@ -1398,7 +1459,9 @@ class DeepseekV2DecoderLayer(nn.Module):
  rope_scaling = getattr(config, "rope_scaling", None)
  max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
  self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+ self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
  self.layer_id = layer_id
+ self.is_nextn = is_nextn
  self.self_attn = DeepseekV2AttentionMLA(
  config=config,
  hidden_size=self.hidden_size,
@@ -1425,7 +1488,7 @@ class DeepseekV2DecoderLayer(nn.Module):

  self.layer_scatter_modes = LayerScatterModes.init_new(
  layer_id=layer_id,
- num_layers=config.num_hidden_layers,
+ num_layers=1 if is_nextn else config.num_hidden_layers,
  is_layer_sparse=self.is_layer_sparse,
  is_previous_layer_sparse=is_previous_layer_sparse,
  )
@@ -1436,6 +1499,7 @@ class DeepseekV2DecoderLayer(nn.Module):
  quant_config=quant_config,
  prefix=add_prefix("mlp", prefix),
  layer_id=self.layer_id,
+ alt_stream=alt_stream,
  )
  else:
  if enable_moe_dense_fully_dp():
@@ -1478,6 +1542,7 @@ class DeepseekV2DecoderLayer(nn.Module):
  residual: Optional[torch.Tensor],
  zero_allocator: BumpAllocator,
  ) -> torch.Tensor:
+
  hidden_states, residual = self.layer_communicator.prepare_attn(
  hidden_states, residual, forward_batch
  )
@@ -1499,6 +1564,11 @@ class DeepseekV2DecoderLayer(nn.Module):
  hidden_states, residual, forward_batch
  )

+ if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
+ # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
+ # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
+ hidden_states = hidden_states.clone()
+
  return hidden_states, residual

  def op_comm_prepare_attn(
@@ -1606,8 +1676,6 @@ class DeepseekV2Model(nn.Module):
  )
  self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

- self.dp_size = get_local_attention_dp_size()
-
  def get_input_embeddings(self) -> torch.Tensor:
  return self.embed_tokens

@@ -1691,7 +1759,6 @@ class DeepseekV2ForCausalLM(nn.Module):
  use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
  )
  self.logits_processor = LogitsProcessor(config)
- self.dp_size = get_local_attention_dp_size()

  self._routed_experts_weights_of_layer = LazyValue(
  lambda: {
@@ -1708,53 +1775,35 @@ class DeepseekV2ForCausalLM(nn.Module):
  def determine_num_fused_shared_experts(
  self, architecture: str = "DeepseekV3ForCausalLM"
  ):
- self.num_fused_shared_experts = (
- 0
- if global_server_args_dict["disable_shared_experts_fusion"]
- else self.config.n_shared_experts
- )
- if self.num_fused_shared_experts > 0:
- # Only Deepseek V3/R1 can use shared experts fusion optimization now.
- if (
- not _is_cuda
- or self.config.architectures[0] != architecture
- or self.config.n_routed_experts != 256
- ):
- self.num_fused_shared_experts = 0
- global_server_args_dict["disable_shared_experts_fusion"] = True
- log_info_on_rank0(
- logger,
- "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
- )
- elif (
- global_server_args_dict["enable_deepep_moe"]
- or global_server_args_dict["enable_ep_moe"]
- ):
- self.num_fused_shared_experts = 0
- global_server_args_dict["disable_shared_experts_fusion"] = True
- log_info_on_rank0(
- logger,
- "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode. Shared experts fusion optimization is disabled.",
- )
- elif self.num_fused_shared_experts == 0:
- if (
- _is_cuda
- and torch.cuda.get_device_capability("cuda") >= (9, 0)
- and self.config.architectures[0] == architecture
- and self.config.n_routed_experts == 256
- and (
- not (
- global_server_args_dict["enable_deepep_moe"]
- or global_server_args_dict["enable_ep_moe"]
- )
- )
- ):
- self.num_fused_shared_experts = self.config.n_shared_experts
- global_server_args_dict["disable_shared_experts_fusion"] = False
- log_info_on_rank0(
- logger,
- "Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
- )
+ self.num_fused_shared_experts = 0
+ if global_server_args_dict["disable_shared_experts_fusion"]:
+ return
+
+ # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+ disable_reason = None
+ if (
+ not _is_cuda
+ or torch.cuda.get_device_capability("cuda") < (8, 0)
+ or self.config.architectures[0] != architecture
+ or self.config.n_routed_experts != 256
+ or self.config.n_shared_experts != 1
+ ):
+ disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
+ elif (
+ global_server_args_dict["enable_deepep_moe"]
+ or global_server_args_dict["enable_ep_moe"]
+ ):
+ disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+
+ if disable_reason is not None:
+ global_server_args_dict["disable_shared_experts_fusion"] = True
+ log_info_on_rank0(
+ logger,
+ f"{disable_reason} Shared experts fusion optimization is disabled.",
+ )
+ return
+
+ self.num_fused_shared_experts = self.config.n_shared_experts

  def get_input_embeddings(self) -> nn.Embedding:
  return self.model.embed_tokens
@@ -1786,8 +1835,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  for name in weight_names:
  if "kv_b_proj" in name:
  layer_id = int(name.split(".")[2])
- # filter the nextn layer.
- if layer_id != self.config.num_hidden_layers:
+ if layer_id < self.config.num_hidden_layers:
  layer_ids.add(layer_id)

  for layer_id in layer_ids:
@@ -1847,8 +1895,10 @@ class DeepseekV2ForCausalLM(nn.Module):
  and weight_block_size[1] == 128
  and model_dtype == torch.bfloat16
  ):
- if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
- "SGL_USE_DEEPGEMM_BMM", "false"
+ if (
+ deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+ and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL
+ and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false")
  ):
  block_scale = weight_scale
  use_deep_gemm_bmm = True
@@ -1932,6 +1982,71 @@ class DeepseekV2ForCausalLM(nn.Module):
  self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
  self_attn.use_deep_gemm_bmm = True

+ if (
+ deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+ and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+ and hasattr(self.quant_config, "weight_block_size")
+ and self.quant_config.weight_block_size is not None
+ ):
+ self._weight_requant_ue8m0(is_nextn)
+
+ def _weight_requant_ue8m0(self, is_nextn=False):
+ weight_block_size = self.quant_config.weight_block_size
+
+ moe_layers = list(
+ range(
+ self.config.first_k_dense_replace,
+ self.config.num_hidden_layers,
+ self.config.moe_layer_freq,
+ )
+ )
+
+ num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+ for layer_id in range(num_hidden_layers):
+ if is_nextn:
+ layer = self.model.decoder
+ else:
+ layer = self.model.layers[layer_id]
+
+ for module in [
+ layer.self_attn.fused_qkv_a_proj_with_mqa,
+ layer.self_attn.q_b_proj,
+ layer.self_attn.kv_b_proj,
+ layer.self_attn.o_proj,
+ ]:
+ requant_weight_ue8m0_inplace(
+ module.weight, module.weight_scale_inv, weight_block_size
+ )
+
+ if layer_id in moe_layers or is_nextn:
+ shared_experts = getattr(layer.mlp, "shared_experts", None)
+ if shared_experts is not None:
+ for module in [
+ shared_experts.gate_up_proj,
+ shared_experts.down_proj,
+ ]:
+ requant_weight_ue8m0_inplace(
+ module.weight, module.weight_scale_inv, weight_block_size
+ )
+
+ experts = layer.mlp.experts
+ if isinstance(experts, DeepEPMoE):
+ for w in [
+ experts.w13_weight_fp8,
+ experts.w2_weight_fp8,
+ ]:
+ requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
+ else:
+ mlp = layer.mlp
+ assert isinstance(mlp, DeepseekV2MLP)
+ for module in [
+ mlp.gate_up_proj,
+ mlp.down_proj,
+ ]:
+ requant_weight_ue8m0_inplace(
+ module.weight, module.weight_scale_inv, weight_block_size
+ )
+
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):

  if is_nextn:
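For orientation, the moe_layers list built in _weight_requant_ue8m0 above follows the config's first_k_dense_replace / moe_layer_freq convention; with the stock DeepSeek-V3/R1 values (an assumption here, other checkpoints may differ) it covers layers 3 through 60:

# Worked example with DeepSeek-V3/R1 config values (first_k_dense_replace=3,
# num_hidden_layers=61, moe_layer_freq=1); other checkpoints may differ.
first_k_dense_replace, num_hidden_layers, moe_layer_freq = 3, 61, 1
moe_layers = list(range(first_k_dense_replace, num_hidden_layers, moe_layer_freq))
assert moe_layers[0] == 3 and moe_layers[-1] == 60 and len(moe_layers) == 58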
@@ -1952,101 +2067,6 @@ class DeepseekV2ForCausalLM(nn.Module):
  ("gate_up_proj", "gate_proj", 0),
  ("gate_up_proj", "up_proj", 1),
  ]
- if self.num_fused_shared_experts > 0:
- assert self.num_fused_shared_experts == 1
- weights_list = list(weights)
- weights_dict = dict(weights_list)
- if self.quant_config is not None:
- if self.quant_config.get_name() == "w8a8_int8":
- suffix_list = [
- "down_proj.weight",
- "down_proj.weight_scale",
- "gate_proj.weight",
- "gate_proj.weight_scale",
- "up_proj.weight",
- "up_proj.weight_scale",
- ]
- elif (
- self.quant_config.get_name() == "fp8"
- or self.quant_config.get_name() == "blockwise_int8"
- ):
- suffix_list = [
- "down_proj.weight",
- "down_proj.weight_scale_inv",
- "gate_proj.weight",
- "gate_proj.weight_scale_inv",
- "up_proj.weight",
- "up_proj.weight_scale_inv",
- ]
- elif self.quant_config.get_name() == "awq":
- suffix_list = [
- "down_proj.qweight",
- "down_proj.qzeros",
- "down_proj.scales",
- "gate_proj.qweight",
- "gate_proj.qzeros",
- "gate_proj.scales",
- "up_proj.qweight",
- "up_proj.qzeros",
- "up_proj.scales",
- ]
- elif self.quant_config.get_name() == "modelopt_fp4":
- suffix_list = [
- "down_proj.weight",
- "down_proj.weight_scale",
- "down_proj.weight_scale_2",
- "down_proj.input_scale",
- "gate_proj.weight",
- "gate_proj.weight_scale",
- "gate_proj.weight_scale_2",
- "gate_proj.input_scale",
- "up_proj.weight",
- "up_proj.weight_scale",
- "up_proj.weight_scale_2",
- "up_proj.input_scale",
- ]
- else:
- raise ValueError(
- f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
- )
- else:
- suffix_list = [
- "down_proj.weight",
- "gate_proj.weight",
- "up_proj.weight",
- ]
- names_to_remove = []
-
- moe_layers = (
- range(
- self.config.first_k_dense_replace,
- self.config.num_hidden_layers,
- self.config.moe_layer_freq,
- )
- if not is_nextn
- else [nextn_layer_id]
- )
-
- for moe_layer in tqdm(
- moe_layers,
- desc=f"Cloning {self.num_fused_shared_experts} "
- "shared expert into MoE",
- ):
- for suffix in suffix_list:
- shared_expert_weight_name = (
- f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
- )
- weights_list.append(
- (
- f"model.layers.{moe_layer}."
- f"mlp.experts."
- f"{self.config.n_routed_experts + 0}"
- f".{suffix}",
- weights_dict[shared_expert_weight_name],
- )
- )
- names_to_remove += [shared_expert_weight_name]
- weights = [w for w in weights_list if w[0] not in names_to_remove]

  # Params for weights, fp8 weight scales, fp8 activation scales
  # (param_name, weight_name, expert_id, shard_id)
@@ -2072,9 +2092,19 @@ class DeepseekV2ForCausalLM(nn.Module):
  "hnorm",
  ]

+ if self.num_fused_shared_experts > 0:
+ assert self.num_fused_shared_experts == 1
+ log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
+
  params_dict = dict(self.named_parameters())
  weight_names = []
  for name, loaded_weight in weights:
+ if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
+ name = name.replace(
+ "mlp.shared_experts",
+ f"mlp.experts.{self.config.n_routed_experts}",
+ )
+
  weight_names.append(name)

  if not is_nextn:
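The hunk above replaces the old bulk cloning of shared-expert weights with an on-the-fly key rewrite: when fusion is active, the single shared expert is stored as routed expert index n_routed_experts, so its checkpoint keys are renamed while the weights stream in. A tiny sketch of that mapping (256 routed experts, as in DeepSeek V3/R1, is just an example value):

n_routed_experts = 256  # example value (DeepSeek V3/R1)

def remap_shared_expert_key(name: str) -> str:
    # Route the shared expert's weights to the extra fused-expert slot.
    if "mlp.shared_experts" in name:
        return name.replace("mlp.shared_experts", f"mlp.experts.{n_routed_experts}")
    return name

assert (
    remap_shared_expert_key("model.layers.3.mlp.shared_experts.gate_proj.weight")
    == "model.layers.3.mlp.experts.256.gate_proj.weight"
)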
@@ -2170,8 +2200,14 @@ class DeepseekV2ForCausalLM(nn.Module):
  ):
  q_a_proj_weight = cached_a_proj[q_a_proj_name]
  kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+ cat_dim = 0
+ if self.quant_config is not None and (
+ self.quant_config.get_name() == "awq"
+ or self.quant_config.get_name() == "moe_wna16"
+ ):
+ cat_dim = 1
  fused_weight = torch.cat(
- [q_a_proj_weight, kv_a_proj_weight], dim=0
+ [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
  )
  param_name = (
  name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
@@ -2193,12 +2229,16 @@ class DeepseekV2ForCausalLM(nn.Module):
  "k_scale" in name or "v_scale" in name
  ) and name not in params_dict:
  # modelopt attn kv scale is named differently
- if any(scale in name for scale in ["k_scale", "v_scale"]):
- name = name.replace("_proj", "attn_mqa")
- else:
- logger.warning(
- f"Unknown scale found in checkpoint: {name}"
- )
+ for scale in ["k_scale", "v_scale"]:
+ if scale in name:
+ name = name.replace(f"{scale[0]}_proj", "attn_mqa")
+ break
+ if name not in params_dict:
+ # modelopt ckpt contains not needed weights for MTP module:
+ # model.decoder.self_attn.attn_mqa.v_scale and
+ # model.decoder.self_attn.attn_mqa.k_scale
+ logger.warning(f"{name} not found in params_dict.")
+ continue
  param = params_dict[name]
  weight_loader = getattr(
  param, "weight_loader", default_weight_loader
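The rewritten scale handling above relies on scale[0] being "k" or "v", so the matching "<k|v>_proj" segment of a modelopt checkpoint key is remapped onto attn_mqa, and keys that still have no matching parameter (such as the MTP module's attn_mqa scales mentioned in the comment) are skipped with a warning instead of falling through. A tiny sketch of the renaming, using a hypothetical checkpoint key for illustration:

def remap_kv_scale(name: str) -> str:
    # Rewrite "<k|v>_proj" to "attn_mqa" for k_scale / v_scale entries.
    for scale in ["k_scale", "v_scale"]:
        if scale in name:
            return name.replace(f"{scale[0]}_proj", "attn_mqa")
    return name

# Hypothetical key, for illustration only:
assert (
    remap_kv_scale("model.layers.0.self_attn.k_proj.k_scale")
    == "model.layers.0.self_attn.attn_mqa.k_scale"
)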