sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. The information is provided for informational purposes only.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +14 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +301 -64
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +20 -15
- sglang/srt/disaggregation/utils.py +47 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +27 -31
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +897 -0
- sglang/srt/entrypoints/openai/serving_completions.py +425 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +28 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +43 -23
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +44 -2
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +14 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +286 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
- sglang/srt/layers/moe/topk.py +117 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +144 -12
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +19 -14
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +49 -32
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +189 -68
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +77 -46
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +27 -8
- sglang/srt/model_loader/loader.py +50 -8
- sglang/srt/model_loader/weight_utils.py +100 -2
- sglang/srt/models/deepseek_nextn.py +35 -30
- sglang/srt/models/deepseek_v2.py +255 -30
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +51 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +248 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -72,7 +72,7 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -93,10 +93,13 @@ from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
     LazyValue,
+    PackWeightMethod,
     add_prefix,
     bind_or_assign,
+    cpu_has_amx_support,
     get_bool_env_var,
     get_int_env_var,
+    is_cpu,
     is_cuda,
     is_hip,
     is_non_idle_and_non_empty,
@@ -107,9 +110,13 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
+elif _is_cpu and _is_cpu_amx_available:
+    pass
 else:
     from vllm._custom_ops import awq_dequantize
 
@@ -118,8 +125,6 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )
 
-if _use_aiter:
-    from aiter.rotary_embedding import get_rope
 
 logger = logging.getLogger(__name__)
 
@@ -138,6 +143,9 @@ class AttnForwardMethod(IntEnum):
     # Use MLA but with fused RoPE
     MLA_FUSED_ROPE = auto()
 
+    # Use MLA with fused RoPE kernel for CPU
+    MLA_FUSED_ROPE_CPU = auto()
+
 
 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -206,8 +214,18 @@ class MoEGate(nn.Module):
             )
         else:
            self.e_score_correction_bias = None
+        if _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])
 
     def forward(self, hidden_states):
+        if getattr(self, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                hidden_states,
+                self.weight,
+                None,  # bias
+                True,  # is_vnni
+            )
+
         logits = F.linear(hidden_states, self.weight, None)
         return logits
 
@@ -220,6 +238,7 @@ class DeepseekV2MoE(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -232,6 +251,7 @@ class DeepseekV2MoE(nn.Module):
         )
         self.config = config
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -269,6 +289,15 @@ class DeepseekV2MoE(nn.Module):
                 if global_server_args_dict["enable_deepep_moe"]
                 else {}
             ),
+            # Additional args for FusedMoE
+            **(
+                dict(
+                    enable_flashinfer_moe=True,
+                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+                )
+                if global_server_args_dict["enable_flashinfer_moe"]
+                else {}
+            ),
         )
 
         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
@@ -332,10 +361,38 @@ class DeepseekV2MoE(nn.Module):
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
-            return self.forward_normal(hidden_states)
+            DUAL_STREAM_TOKEN_THRESHOLD = 1024
+            if (
+                self.alt_stream is not None
+                and self.num_fused_shared_experts == 0
+                and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            ):
+                return self.forward_normal_dual_stream(hidden_states)
+            else:
+                return self.forward_normal(hidden_states)
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
+    def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states)
+
+        with torch.cuda.stream(self.alt_stream):
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
+            if not _is_cuda:
+                final_hidden_states *= self.routed_scaling_factor
+        current_stream.wait_stream(self.alt_stream)
+        final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
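The `forward_normal_dual_stream` path introduced in the hunk above overlaps the shared-expert MLP with the routed experts by issuing them on two CUDA streams and joining before the outputs are summed; the `DUAL_STREAM_TOKEN_THRESHOLD = 1024` guard restricts it to small batches where the extra synchronization pays off. Below is a minimal, self-contained sketch of that overlap pattern; `gate`, `shared_experts`, `routed_experts`, and `alt_stream` are placeholder callables/objects for illustration, not the sglang modules.

import torch

def dual_stream_moe(gate, shared_experts, routed_experts, alt_stream, hidden_states):
    """Overlap the shared-expert MLP with the routed experts on two CUDA streams."""
    router_logits = gate(hidden_states)

    current_stream = torch.cuda.current_stream()
    # Let alt_stream see everything queued so far (e.g. the gate GEMM).
    alt_stream.wait_stream(current_stream)

    # Shared experts stay on the current stream ...
    shared_output = shared_experts(hidden_states)

    # ... while the routed experts are launched concurrently on alt_stream.
    with torch.cuda.stream(alt_stream):
        routed_output = routed_experts(hidden_states, router_logits)

    # Join the streams before mixing the two partial results.
    current_stream.wait_stream(alt_stream)
    return routed_output + shared_output

if torch.cuda.is_available():
    dim, n_experts = 128, 8
    gate_w = torch.randn(n_experts, dim, device="cuda")
    mlp_w = torch.randn(dim, dim, device="cuda")
    out = dual_stream_moe(
        gate=lambda x: x @ gate_w.t(),
        shared_experts=lambda x: torch.relu(x @ mlp_w),
        routed_experts=lambda x, logits: (x @ mlp_w) * logits.softmax(-1).amax(-1, keepdim=True),
        alt_stream=torch.cuda.Stream(),
        hidden_states=torch.randn(64, dim, device="cuda"),
    )
    print(out.shape)  # torch.Size([64, 128])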
@@ -343,7 +400,8 @@ class DeepseekV2MoE(nn.Module):
         final_hidden_states = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
-        if not _is_cuda:
+        if not _is_cuda and not _use_aiter:
+            # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
@@ -665,13 +723,14 @@ class DeepseekV2AttentionMLA(nn.Module):
         if rope_scaling:
             rope_scaling["rope_type"] = "deepseek_yarn"
 
-        self.rotary_emb = get_rope(
+        self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
+            device=global_server_args_dict["device"],
         )
 
         if rope_scaling:
@@ -731,6 +790,37 @@ class DeepseekV2AttentionMLA(nn.Module):
             "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
         )
 
+        # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we will choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel
+        # which requires self.w_kc and self.w_vc to be packed.
+        # If not, we will use torch.bmm and weight shouldn't be packed in this case
+        if (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and _is_cpu
+            and _is_cpu_amx_available
+        ):
+            self.quant_method = PackWeightMethod(
+                weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
+            )
+
+        self.qkv_proj_with_rope_is_int8 = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
+        )
+        self.qkv_proj_with_rope_is_fp8 = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
+        )
+
+        self.weight_block_size = None
+        if self.qkv_proj_with_rope_is_fp8:
+            assert (
+                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                == self.q_b_proj.quant_method.quant_config.weight_block_size
+            )
+            self.weight_block_size = (
+                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+            )
+
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
@@ -744,7 +834,12 @@ class DeepseekV2AttentionMLA(nn.Module):
                 else:
                     return AttnForwardMethod.MLA
             else:
-                return AttnForwardMethod.MLA
+                if hasattr(self, "fused_qkv_a_proj_with_mqa") and getattr(
+                    self, "use_intel_amx_backend", False
+                ):
+                    return AttnForwardMethod.MLA_FUSED_ROPE_CPU
+                else:
+                    return AttnForwardMethod.MLA
 
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
@@ -858,6 +953,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             inner_state = self.forward_absorb_fused_mla_rope_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+            inner_state = self.forward_absorb_fused_mla_rope_cpu_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         else:
             raise NotImplementedError
         return None, attn_forward_method, forward_batch, inner_state
@@ -877,6 +976,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_absorb_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
             return self.forward_absorb_fused_mla_rope_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+            return self.forward_absorb_fused_mla_rope_cpu_core(*inner_state)
         else:
             raise NotImplementedError
 
@@ -1040,13 +1141,16 @@ class DeepseekV2AttentionMLA(nn.Module):
                 masked_m,
                 expected_m,
             )
-            attn_bmm_output =
+            attn_bmm_output = (
+                attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2)
+            )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
+            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
@@ -1059,10 +1163,21 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_scale,
                 torch.bfloat16,
             )
+            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         else:
-            attn_bmm_output = torch.
-
-
+            attn_bmm_output = torch.empty(
+                (attn_output.shape[0], self.num_local_heads * self.v_head_dim),
+                dtype=attn_output.dtype,
+                device=attn_output.device,
+            )
+            torch.bmm(
+                attn_output.transpose(0, 1),
+                self.w_vc,
+                out=attn_bmm_output.view(
+                    -1, self.num_local_heads, self.v_head_dim
+                ).transpose(0, 1),
+            )
+
         output, _ = self.o_proj(attn_bmm_output)
 
         return output
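In the non-quantized `else:` branch above, the `w_vc` projection now writes straight into a flat `[tokens, num_local_heads * v_head_dim]` buffer by passing a transposed view to `torch.bmm(out=...)`, avoiding the explicit transpose-and-flatten copy the other branches still perform. A small sketch of that trick with made-up shapes (not the model's):

import torch

# Illustrative sizes only: M tokens, B heads, K kv_lora_rank, N v_head_dim.
M, B, K, N = 6, 4, 8, 5
attn_output = torch.randn(M, B, K)   # [M, B, K]
w_vc = torch.randn(B, K, N)          # [B, K, N]

# Allocate the flat output once and let bmm write through a [B, M, N] view of it.
flat_out = torch.empty(M, B * N)
torch.bmm(
    attn_output.transpose(0, 1),                 # [B, M, K]
    w_vc,                                        # [B, K, N]
    out=flat_out.view(M, B, N).transpose(0, 1),  # [B, M, N] view into flat_out
)

# Reference path with an explicit transpose-and-flatten copy afterwards.
ref = torch.bmm(attn_output.transpose(0, 1), w_vc).transpose(0, 1).flatten(1, 2)
assert torch.allclose(flat_out, ref, atol=1e-5)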
@@ -1180,6 +1295,57 @@ class DeepseekV2AttentionMLA(nn.Module):
             zero_allocator,
         )
 
+    def forward_absorb_fused_mla_rope_cpu_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        assert self.q_lora_rank is not None and getattr(
+            self, "use_intel_amx_backend", False
+        ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"
+
+        q_input, k_input, v_input = (
+            torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight(
+                hidden_states,
+                self.fused_qkv_a_proj_with_mqa.weight,
+                self.q_b_proj.weight,
+                self.w_kc,
+                self.q_a_layernorm.weight,
+                self.kv_a_layernorm.weight,
+                positions,
+                self.rotary_emb.cos_sin_cache,
+                self.kv_a_layernorm.variance_epsilon,
+                self.qkv_proj_with_rope_is_int8,
+                self.qkv_proj_with_rope_is_fp8,
+                (
+                    self.fused_qkv_a_proj_with_mqa.weight_scale
+                    if self.qkv_proj_with_rope_is_int8
+                    else (
+                        self.fused_qkv_a_proj_with_mqa.weight_scale_inv
+                        if self.qkv_proj_with_rope_is_fp8
+                        else None
+                    )
+                ),
+                (
+                    self.q_b_proj.weight_scale
+                    if self.qkv_proj_with_rope_is_int8
+                    else (
+                        self.q_b_proj.weight_scale_inv
+                        if self.qkv_proj_with_rope_is_fp8
+                        else None
+                    )
+                ),
+                True,  # is_vnni
+                self.weight_block_size,
+                self.q_lora_rank,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+            )
+        )
+        return (q_input, k_input, v_input, forward_batch, zero_allocator)
+
     def forward_absorb_fused_mla_rope_core(
         self,
         q_input,
@@ -1253,6 +1419,43 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return output
 
+    def forward_absorb_fused_mla_rope_cpu_core(
+        self, q_input, k_input, v_input, forward_batch, zero_allocator
+    ):
+        assert self.q_lora_rank is not None and getattr(
+            self, "use_intel_amx_backend", False
+        ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"
+
+        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        # [Note] Align shapes of bmm inputs.
+        # Shapes of inputs:
+        # q_nope: [M, B, K]
+        # original self.w_kc: [B, K, N]
+        # current self.w_kc (which has been converted in PackWeightMethod): [B, N, K]
+
+        # Shapes of inputs to sgl_kernel.cpu.bmm:
+        # out: [B, M, N]
+        # mat1: [B, M, K]
+        # mat2: [B, N, K]
+        B = self.w_vc.size(0)
+        N = self.w_vc.size(1)
+        M = attn_output.size(0)
+        output = torch.empty([M, int(B * N)], dtype=attn_output.dtype)
+        attn_bmm_output = output.view([M, B, N]).transpose_(0, 1)
+        torch.ops.sgl_kernel.bmm_cpu(
+            attn_bmm_output,
+            attn_output.transpose(0, 1),
+            self.w_vc,
+            True,  # is_vnni
+            None,  # scale
+        )
+        attn_output = output
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
     def _chunked_prefix_attn_mha(
         self,
         q: torch.Tensor,
@@ -1399,7 +1602,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
         self.layer_id = layer_id
+        self.is_nextn = is_nextn
         self.self_attn = DeepseekV2AttentionMLA(
             config=config,
             hidden_size=self.hidden_size,
@@ -1426,7 +1631,7 @@ class DeepseekV2DecoderLayer(nn.Module):
 
         self.layer_scatter_modes = LayerScatterModes.init_new(
             layer_id=layer_id,
-            num_layers=config.num_hidden_layers,
+            num_layers=1 if is_nextn else config.num_hidden_layers,
             is_layer_sparse=self.is_layer_sparse,
             is_previous_layer_sparse=is_previous_layer_sparse,
         )
@@ -1437,6 +1642,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
+                alt_stream=alt_stream,
             )
         else:
             if enable_moe_dense_fully_dp():
@@ -1479,6 +1685,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
+
         hidden_states, residual = self.layer_communicator.prepare_attn(
             hidden_states, residual, forward_batch
         )
@@ -1500,6 +1707,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
+        if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
+            # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
+            # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
+            hidden_states = hidden_states.clone()
+
         return hidden_states, residual
 
     def op_comm_prepare_attn(
@@ -1607,8 +1819,6 @@ class DeepseekV2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        self.dp_size = get_local_attention_dp_size()
-
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
 
@@ -1692,7 +1902,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_local_attention_dp_size()
 
         self._routed_experts_weights_of_layer = LazyValue(
             lambda: {
@@ -1717,12 +1926,12 @@ class DeepseekV2ForCausalLM(nn.Module):
         disable_reason = None
         if (
             not _is_cuda
-            or torch.cuda.get_device_capability("cuda") < (
+            or torch.cuda.get_device_capability("cuda") < (8, 0)
             or self.config.architectures[0] != architecture
             or self.config.n_routed_experts != 256
             or self.config.n_shared_experts != 1
         ):
-            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >=
+            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif (
             global_server_args_dict["enable_deepep_moe"]
             or global_server_args_dict["enable_ep_moe"]
@@ -1919,10 +2128,12 @@ class DeepseekV2ForCausalLM(nn.Module):
         if (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+            and hasattr(self.quant_config, "weight_block_size")
+            and self.quant_config.weight_block_size is not None
         ):
-            self._weight_requant_ue8m0()
+            self._weight_requant_ue8m0(is_nextn)
 
-    def _weight_requant_ue8m0(self):
+    def _weight_requant_ue8m0(self, is_nextn=False):
         weight_block_size = self.quant_config.weight_block_size
 
         moe_layers = list(
@@ -1933,8 +2144,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             )
         )
 
-        for layer_id in range(self.config.num_hidden_layers):
-            layer = self.model.layers[layer_id]
+        num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+        for layer_id in range(num_hidden_layers):
+            if is_nextn:
+                layer = self.model.decoder
+            else:
+                layer = self.model.layers[layer_id]
 
             for module in [
                 layer.self_attn.fused_qkv_a_proj_with_mqa,
@@ -1946,7 +2161,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
 
-            if layer_id in moe_layers:
+            if layer_id in moe_layers or is_nextn:
                 shared_experts = getattr(layer.mlp, "shared_experts", None)
                 if shared_experts is not None:
                     for module in [
@@ -2022,7 +2237,7 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         if self.num_fused_shared_experts > 0:
             assert self.num_fused_shared_experts == 1
-            logger
+            log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
 
         params_dict = dict(self.named_parameters())
         weight_names = []
@@ -2128,8 +2343,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                     ):
                         q_a_proj_weight = cached_a_proj[q_a_proj_name]
                         kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                        cat_dim = 0
+                        if self.quant_config is not None and (
+                            self.quant_config.get_name() == "awq"
+                            or self.quant_config.get_name() == "moe_wna16"
+                        ):
+                            cat_dim = 1
                         fused_weight = torch.cat(
-                            [q_a_proj_weight, kv_a_proj_weight], dim=0
+                            [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                        )
                         param_name = (
                             name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
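The `cat_dim` logic added above controls how the separate `q_a_proj` and `kv_a_proj_with_mqa` weights are concatenated into `fused_qkv_a_proj_with_mqa` at load time: ordinary `[out_features, in_features]` weights stack along dim 0, while AWQ and moe_wna16 checkpoints stack along dim 1, presumably because those formats store packed weights with the output dimension last. A tiny sketch under that assumption (`fuse_a_proj_weights` and `quant_name` are hypothetical names, not sglang's loader API):

import torch

def fuse_a_proj_weights(q_a_proj_weight, kv_a_proj_weight, quant_name=None):
    # Plain [out_features, in_features] Linear weights stack along dim 0;
    # AWQ / moe_wna16 packed weights keep out_features on dim 1, so stack there.
    cat_dim = 1 if quant_name in ("awq", "moe_wna16") else 0
    return torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=cat_dim)

# Unquantized example with made-up sizes: both weights share in_features=16.
q_a = torch.randn(8, 16)
kv_a = torch.randn(4, 16)
assert fuse_a_proj_weights(q_a, kv_a).shape == (12, 16)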
@@ -2151,12 +2372,16 @@ class DeepseekV2ForCausalLM(nn.Module):
                     "k_scale" in name or "v_scale" in name
                 ) and name not in params_dict:
                     # modelopt attn kv scale is named differently
-
-
-
-
-
-
+                    for scale in ["k_scale", "v_scale"]:
+                        if scale in name:
+                            name = name.replace(f"{scale[0]}_proj", "attn_mqa")
+                            break
+                    if name not in params_dict:
+                        # modelopt ckpt contains not needed weights for MTP module:
+                        # model.decoder.self_attn.attn_mqa.v_scale and
+                        # model.decoder.self_attn.attn_mqa.k_scale
+                        logger.warning(f"{name} not found in params_dict.")
+                        continue
                 param = params_dict[name]
                 weight_loader = getattr(
                     param, "weight_loader", default_weight_loader