sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -51,11 +51,11 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
@@ -66,12 +66,13 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
+    requant_weight_ue8m0_inplace,
 )
 from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -94,8 +95,10 @@ from sglang.srt.utils import (
     LazyValue,
     add_prefix,
     bind_or_assign,
+    cpu_has_amx_support,
     get_bool_env_var,
     get_int_env_var,
+    is_cpu,
     is_cuda,
     is_hip,
     is_non_idle_and_non_empty,
@@ -106,13 +109,13 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
-
-    from sglang.srt.layers.quantization.deep_gemm import (
-        grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
-    )
+elif _is_cpu and _is_cpu_amx_available:
+    pass
 else:
     from vllm._custom_ops import awq_dequantize
 
@@ -223,6 +226,7 @@ class DeepseekV2MoE(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -235,6 +239,7 @@ class DeepseekV2MoE(nn.Module):
         )
         self.config = config
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -272,6 +277,15 @@ class DeepseekV2MoE(nn.Module):
                 if global_server_args_dict["enable_deepep_moe"]
                 else {}
             ),
+            # Additional args for FusedMoE
+            **(
+                dict(
+                    enable_flashinfer_moe=True,
+                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+                )
+                if global_server_args_dict["enable_flashinfer_moe"]
+                else {}
+            ),
         )
 
         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
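The extra FusedMoE arguments are forwarded with the `**(dict(...) if flag else {})` idiom, so the constructor only receives them when the corresponding server flag is set. A minimal standalone illustration of that pattern (the constructor and flag below are made up, not sglang APIs):

```python
def make_moe(num_experts, **extra):
    # Stand-in for a MoE constructor that only accepts backend-specific
    # kwargs when the matching backend is actually enabled.
    return {"num_experts": num_experts, **extra}

enable_flashinfer_moe = True

moe = make_moe(
    256,
    # Unpacks to nothing when the flag is off, so no stray kwargs are passed.
    **(dict(enable_flashinfer_moe=True) if enable_flashinfer_moe else {}),
)
print(moe)  # {'num_experts': 256, 'enable_flashinfer_moe': True}
```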
@@ -335,10 +349,38 @@ class DeepseekV2MoE(nn.Module):
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
-            return self.forward_normal(hidden_states)
+            DUAL_STREAM_TOKEN_THRESHOLD = 1024
+            if (
+                self.alt_stream is not None
+                and self.num_fused_shared_experts == 0
+                and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            ):
+                return self.forward_normal_dual_stream(hidden_states)
+            else:
+                return self.forward_normal(hidden_states)
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
+    def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states)
+
+        with torch.cuda.stream(self.alt_stream):
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
+            if not _is_cuda:
+                final_hidden_states *= self.routed_scaling_factor
+        current_stream.wait_stream(self.alt_stream)
+        final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
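The new `forward_normal_dual_stream` path runs the shared-expert MLP on the current CUDA stream while the routed experts run concurrently on `alt_stream`, then synchronizes before summing the two results; the `DUAL_STREAM_TOKEN_THRESHOLD` check restricts this to batches of at most 1024 tokens, where the overlap is most likely to hide latency. A minimal sketch of the same overlap pattern in plain PyTorch (`shared_mlp` and `routed_moe` are hypothetical stand-ins for sglang's modules):

```python
import torch

def dual_stream_forward(x, shared_mlp, routed_moe, alt_stream):
    current_stream = torch.cuda.current_stream()
    # alt_stream must not start until x is ready on the current stream.
    alt_stream.wait_stream(current_stream)

    shared_out = shared_mlp(x)  # issued on the current stream

    with torch.cuda.stream(alt_stream):
        routed_out = routed_moe(x)  # issued concurrently on alt_stream

    # Do not read routed_out on the current stream until alt_stream is done,
    # otherwise the addition below would race with the routed-expert kernels.
    current_stream.wait_stream(alt_stream)
    return routed_out + shared_out

# Usage sketch:
#   alt_stream = torch.cuda.Stream()
#   y = dual_stream_forward(x, shared_mlp, routed_moe, alt_stream)
```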
@@ -668,13 +710,14 @@ class DeepseekV2AttentionMLA(nn.Module):
         if rope_scaling:
             rope_scaling["rope_type"] = "deepseek_yarn"
 
-        self.rotary_emb = get_rope(
+        self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
+            device=global_server_args_dict["device"],
         )
 
         if rope_scaling:
@@ -980,7 +1023,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope.new_empty(
                 (self.num_local_heads, aligned_m, self.kv_lora_rank)
             )
-            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                 (q_nope_val, q_nope_scale),
                 (self.w_kc, self.w_scale_k),
                 q_nope_out,
@@ -1013,7 +1056,11 @@ class DeepseekV2AttentionMLA(nn.Module):
     def forward_absorb_core(
         self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
     ):
-        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
+        if (
+            self.attention_backend == "fa3"
+            or self.attention_backend == "flashinfer"
+            or self.attention_backend == "cutlass_mla"
+        ):
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
@@ -1032,20 +1079,23 @@ class DeepseekV2AttentionMLA(nn.Module):
             attn_bmm_output = attn_output.new_empty(
                 (self.num_local_heads, aligned_m, self.v_head_dim)
             )
-            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                 (attn_output_val, attn_output_scale),
                 (self.w_vc, self.w_scale_v),
                 attn_bmm_output,
                 masked_m,
                 expected_m,
             )
-            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
+            attn_bmm_output = (
+                attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2)
+            )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
+            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
@@ -1058,10 +1108,21 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_scale,
                 torch.bfloat16,
             )
+            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         else:
-            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
-        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
-        output, _ = self.o_proj(attn_output)
+            attn_bmm_output = torch.empty(
+                (attn_output.shape[0], self.num_local_heads * self.v_head_dim),
+                dtype=attn_output.dtype,
+                device=attn_output.device,
+            )
+            torch.bmm(
+                attn_output.transpose(0, 1),
+                self.w_vc,
+                out=attn_bmm_output.view(
+                    -1, self.num_local_heads, self.v_head_dim
+                ).transpose(0, 1),
+            )
+        output, _ = self.o_proj(attn_bmm_output)
 
         return output
 
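The rewritten fallback branch preallocates `attn_bmm_output` in the `(tokens, heads * v_head_dim)` layout that `o_proj` consumes and lets `torch.bmm` write into a transposed view of that buffer, so the separate transpose-and-flatten step the old code performed after the `if/elif` chain is no longer needed. A self-contained illustration of the same trick with made-up sizes:

```python
import torch

tokens, heads, d_in, d_out = 16, 8, 32, 64
attn = torch.randn(tokens, heads, d_in)  # per-token, per-head attention output
w = torch.randn(heads, d_in, d_out)      # per-head projection weights

# Allocate the final 2-D buffer the downstream linear layer expects ...
out = torch.empty(tokens, heads * d_out)

# ... and have bmm fill a (heads, tokens, d_out) view of it directly.
torch.bmm(
    attn.transpose(0, 1),  # (heads, tokens, d_in)
    w,                     # (heads, d_in, d_out)
    out=out.view(tokens, heads, d_out).transpose(0, 1),
)

ref = torch.bmm(attn.transpose(0, 1), w).transpose(0, 1).reshape(tokens, heads * d_out)
assert torch.allclose(out, ref)
```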
@@ -1398,7 +1459,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
         self.layer_id = layer_id
+        self.is_nextn = is_nextn
         self.self_attn = DeepseekV2AttentionMLA(
             config=config,
             hidden_size=self.hidden_size,
@@ -1425,7 +1488,7 @@ class DeepseekV2DecoderLayer(nn.Module):
 
         self.layer_scatter_modes = LayerScatterModes.init_new(
             layer_id=layer_id,
-            num_layers=config.num_hidden_layers,
+            num_layers=1 if is_nextn else config.num_hidden_layers,
             is_layer_sparse=self.is_layer_sparse,
             is_previous_layer_sparse=is_previous_layer_sparse,
         )
@@ -1436,6 +1499,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
+                alt_stream=alt_stream,
             )
         else:
             if enable_moe_dense_fully_dp():
@@ -1478,6 +1542,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
+
         hidden_states, residual = self.layer_communicator.prepare_attn(
             hidden_states, residual, forward_batch
         )
@@ -1499,6 +1564,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
+        if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
+            # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
+            # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
+            hidden_states = hidden_states.clone()
+
         return hidden_states, residual
 
     def op_comm_prepare_attn(
@@ -1606,8 +1676,6 @@ class DeepseekV2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        self.dp_size = get_local_attention_dp_size()
-
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
 
@@ -1691,7 +1759,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_local_attention_dp_size()
 
         self._routed_experts_weights_of_layer = LazyValue(
             lambda: {
@@ -1708,53 +1775,35 @@ class DeepseekV2ForCausalLM(nn.Module):
     def determine_num_fused_shared_experts(
         self, architecture: str = "DeepseekV3ForCausalLM"
     ):
-        self.num_fused_shared_experts =
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if (
-            _is_cuda
-            and torch.cuda.get_device_capability("cuda") >= (9, 0)
-            and self.config.architectures[0] == architecture
-            and self.config.n_routed_experts == 256
-            and (
-                not (
-                    global_server_args_dict["enable_deepep_moe"]
-                    or global_server_args_dict["enable_ep_moe"]
-                )
-            )
-        ):
-            self.num_fused_shared_experts = self.config.n_shared_experts
-            global_server_args_dict["disable_shared_experts_fusion"] = False
-            log_info_on_rank0(
-                logger,
-                "Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
-            )
+        self.num_fused_shared_experts = 0
+        if global_server_args_dict["disable_shared_experts_fusion"]:
+            return
+
+        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+        disable_reason = None
+        if (
+            not _is_cuda
+            or torch.cuda.get_device_capability("cuda") < (8, 0)
+            or self.config.architectures[0] != architecture
+            or self.config.n_routed_experts != 256
+            or self.config.n_shared_experts != 1
+        ):
+            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
+        elif (
+            global_server_args_dict["enable_deepep_moe"]
+            or global_server_args_dict["enable_ep_moe"]
+        ):
+            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+
+        if disable_reason is not None:
+            global_server_args_dict["disable_shared_experts_fusion"] = True
+            log_info_on_rank0(
+                logger,
+                f"{disable_reason} Shared experts fusion optimization is disabled.",
+            )
+            return
+
+        self.num_fused_shared_experts = self.config.n_shared_experts
 
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
@@ -1786,8 +1835,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         for name in weight_names:
             if "kv_b_proj" in name:
                 layer_id = int(name.split(".")[2])
-
-                if layer_id != self.config.num_hidden_layers:
+                if layer_id < self.config.num_hidden_layers:
                     layer_ids.add(layer_id)
 
         for layer_id in layer_ids:
@@ -1847,8 +1895,10 @@ class DeepseekV2ForCausalLM(nn.Module):
                     and weight_block_size[1] == 128
                     and model_dtype == torch.bfloat16
                 ):
-                    if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
-                        "SGL_USE_DEEPGEMM_BMM", "false"
+                    if (
+                        deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                        and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL
+                        and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false")
                     ):
                         block_scale = weight_scale
                         use_deep_gemm_bmm = True
@@ -1932,6 +1982,71 @@ class DeepseekV2ForCausalLM(nn.Module):
                 self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
                 self_attn.use_deep_gemm_bmm = True
 
+        if (
+            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+            and hasattr(self.quant_config, "weight_block_size")
+            and self.quant_config.weight_block_size is not None
+        ):
+            self._weight_requant_ue8m0(is_nextn)
+
+    def _weight_requant_ue8m0(self, is_nextn=False):
+        weight_block_size = self.quant_config.weight_block_size
+
+        moe_layers = list(
+            range(
+                self.config.first_k_dense_replace,
+                self.config.num_hidden_layers,
+                self.config.moe_layer_freq,
+            )
+        )
+
+        num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+        for layer_id in range(num_hidden_layers):
+            if is_nextn:
+                layer = self.model.decoder
+            else:
+                layer = self.model.layers[layer_id]
+
+            for module in [
+                layer.self_attn.fused_qkv_a_proj_with_mqa,
+                layer.self_attn.q_b_proj,
+                layer.self_attn.kv_b_proj,
+                layer.self_attn.o_proj,
+            ]:
+                requant_weight_ue8m0_inplace(
+                    module.weight, module.weight_scale_inv, weight_block_size
+                )
+
+            if layer_id in moe_layers or is_nextn:
+                shared_experts = getattr(layer.mlp, "shared_experts", None)
+                if shared_experts is not None:
+                    for module in [
+                        shared_experts.gate_up_proj,
+                        shared_experts.down_proj,
+                    ]:
+                        requant_weight_ue8m0_inplace(
+                            module.weight, module.weight_scale_inv, weight_block_size
+                        )
+
+                experts = layer.mlp.experts
+                if isinstance(experts, DeepEPMoE):
+                    for w in [
+                        experts.w13_weight_fp8,
+                        experts.w2_weight_fp8,
+                    ]:
+                        requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
+            else:
+                mlp = layer.mlp
+                assert isinstance(mlp, DeepseekV2MLP)
+                for module in [
+                    mlp.gate_up_proj,
+                    mlp.down_proj,
+                ]:
+                    requant_weight_ue8m0_inplace(
+                        module.weight, module.weight_scale_inv, weight_block_size
+                    )
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
 
         if is_nextn:
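`_weight_requant_ue8m0` walks the attention projections, dense MLPs, shared experts, and MoE experts of every layer (plus the NextN draft layer) and calls the new `requant_weight_ue8m0_inplace` helper from `fp8_utils`, so that block-wise FP8 scales conform to UE8M0, an exponent-only format in which every scale is a power of two. The helper's internals are not part of this hunk; the snippet below is only a rough sketch of the per-block rounding idea, under the assumption that requantization rounds each scale up to the nearest power of two and rescales the quantized block to compensate:

```python
import torch

def requant_block_ue8m0_sketch(q_block: torch.Tensor, scale: torch.Tensor):
    # q_block holds quantized values; the represented value is q_block * scale.
    new_scale = torch.exp2(torch.ceil(torch.log2(scale)))  # power of two, >= scale
    new_q = q_block * (scale / new_scale)  # keeps q_block * scale unchanged
    return new_q, new_scale
```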
@@ -1952,101 +2067,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.num_fused_shared_experts > 0:
-            assert self.num_fused_shared_experts == 1
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is not None:
-                if self.quant_config.get_name() == "w8a8_int8":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif (
-                    self.quant_config.get_name() == "fp8"
-                    or self.quant_config.get_name() == "blockwise_int8"
-                ):
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale_inv",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale_inv",
-                        "up_proj.weight",
-                        "up_proj.weight_scale_inv",
-                    ]
-                elif self.quant_config.get_name() == "awq":
-                    suffix_list = [
-                        "down_proj.qweight",
-                        "down_proj.qzeros",
-                        "down_proj.scales",
-                        "gate_proj.qweight",
-                        "gate_proj.qzeros",
-                        "gate_proj.scales",
-                        "up_proj.qweight",
-                        "up_proj.qzeros",
-                        "up_proj.scales",
-                    ]
-                elif self.quant_config.get_name() == "modelopt_fp4":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "down_proj.weight_scale_2",
-                        "down_proj.input_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "gate_proj.weight_scale_2",
-                        "gate_proj.input_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                        "up_proj.weight_scale_2",
-                        "up_proj.input_scale",
-                    ]
-                else:
-                    raise ValueError(
-                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
-                    )
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "gate_proj.weight",
-                    "up_proj.weight",
-                ]
-            names_to_remove = []
-
-            moe_layers = (
-                range(
-                    self.config.first_k_dense_replace,
-                    self.config.num_hidden_layers,
-                    self.config.moe_layer_freq,
-                )
-                if not is_nextn
-                else [nextn_layer_id]
-            )
-
-            for moe_layer in tqdm(
-                moe_layers,
-                desc=f"Cloning {self.num_fused_shared_experts} "
-                "shared expert into MoE",
-            ):
-                for suffix in suffix_list:
-                    shared_expert_weight_name = (
-                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
-                    )
-                    weights_list.append(
-                        (
-                            f"model.layers.{moe_layer}."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + 0}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                    names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
@@ -2072,9 +2092,19 @@ class DeepseekV2ForCausalLM(nn.Module):
             "hnorm",
         ]
 
+        if self.num_fused_shared_experts > 0:
+            assert self.num_fused_shared_experts == 1
+            log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
+
         params_dict = dict(self.named_parameters())
         weight_names = []
         for name, loaded_weight in weights:
+            if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
+                name = name.replace(
+                    "mlp.shared_experts",
+                    f"mlp.experts.{self.config.n_routed_experts}",
+                )
+
             weight_names.append(name)
 
             if not is_nextn:
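Together with the large block removed above (the up-front cloning of shared-expert weights), this changes how fused shared experts are loaded: `load_weights` now renames `mlp.shared_experts.*` checkpoint entries on the fly so they land in one extra expert slot at index `n_routed_experts`. A tiny standalone sketch of the renaming rule:

```python
def remap_shared_expert(name: str, n_routed_experts: int) -> str:
    # The fused shared expert is stored as one additional routed expert,
    # so its checkpoint tensors are loaded under index n_routed_experts.
    return name.replace("mlp.shared_experts", f"mlp.experts.{n_routed_experts}")

print(remap_shared_expert("model.layers.3.mlp.shared_experts.gate_proj.weight", 256))
# model.layers.3.mlp.experts.256.gate_proj.weight
```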
@@ -2170,8 +2200,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                         ):
                             q_a_proj_weight = cached_a_proj[q_a_proj_name]
                             kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                            cat_dim = 0
+                            if self.quant_config is not None and (
+                                self.quant_config.get_name() == "awq"
+                                or self.quant_config.get_name() == "moe_wna16"
+                            ):
+                                cat_dim = 1
                             fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=0
+                                [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                             )
                             param_name = (
                                 name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
@@ -2193,12 +2229,16 @@ class DeepseekV2ForCausalLM(nn.Module):
                     "k_scale" in name or "v_scale" in name
                 ) and name not in params_dict:
                     # modelopt attn kv scale is named differently
-
-
-
-
-
-
+                    for scale in ["k_scale", "v_scale"]:
+                        if scale in name:
+                            name = name.replace(f"{scale[0]}_proj", "attn_mqa")
+                            break
+                    if name not in params_dict:
+                        # modelopt ckpt contains not needed weights for MTP module:
+                        # model.decoder.self_attn.attn_mqa.v_scale and
+                        # model.decoder.self_attn.attn_mqa.k_scale
+                        logger.warning(f"{name} not found in params_dict.")
+                        continue
                     param = params_dict[name]
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader