sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in the public registry.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -23,10 +23,10 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm import _custom_ops as ops
 
 from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
+    parallel_state,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
@@ -34,11 +34,13 @@ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
     decode_attention_fwd_grouped_rope,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    tp_all_gather,
+    tp_reduce_scatter,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -48,8 +50,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
@@ -65,15 +69,21 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
+from sglang.srt.utils import add_prefix, is_cuda, is_hip
 
 _is_hip = is_hip()
+_is_cuda = is_cuda()
 
-if is_cuda_available():
-    from sgl_kernel import bmm_fp8
+if _is_cuda:
+    from sgl_kernel import awq_dequantize, bmm_fp8
+else:
+    from vllm import _custom_ops as ops
+
+expert_distribution_recorder = ExpertDistributionRecorder()
 
 
 class DeepseekV2MLP(nn.Module):
@@ -85,6 +95,8 @@ class DeepseekV2MLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -93,6 +105,8 @@ class DeepseekV2MLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -101,6 +115,8 @@ class DeepseekV2MLP(nn.Module):
             quant_config=quant_config,
             reduce_results=reduce_results,
             prefix=add_prefix("down_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -165,7 +181,11 @@ class DeepseekV2MoE(nn.Module):
 
         self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
 
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )
         self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
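The hunk above turns the MoE backend choice into a three-way selection. A minimal restatement of that selection order, with plain booleans standing in for the `global_server_args_dict` lookups (illustrative only, not sglang code):

```python
def pick_moe_impl(enable_deepep_moe: bool, enable_ep_moe: bool) -> str:
    # Mirrors the ternary in the hunk above: DeepEP takes precedence over plain
    # expert parallelism, and the Triton fused kernel remains the default.
    if enable_deepep_moe:
        return "DeepEPMoE"
    if enable_ep_moe:
        return "EPMoE"
    return "FusedMoE"


assert pick_moe_impl(False, False) == "FusedMoE"
assert pick_moe_impl(False, True) == "EPMoE"
assert pick_moe_impl(True, True) == "DeepEPMoE"
```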
@@ -182,18 +202,60 @@ class DeepseekV2MoE(nn.Module):
 
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            self.shared_experts = DeepseekV2MLP(
+            # disable tp for shared experts when enable deepep moe
+            if not global_server_args_dict["enable_deepep_moe"]:
+                self.shared_experts = DeepseekV2MLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=intermediate_size,
+                    hidden_act=config.hidden_act,
+                    quant_config=quant_config,
+                    reduce_results=False,
+                    prefix=add_prefix("shared_experts", prefix),
+                )
+            else:
+                self.shared_experts = DeepseekV2MLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=intermediate_size,
+                    hidden_act=config.hidden_act,
+                    quant_config=quant_config,
+                    reduce_results=False,
+                    prefix=add_prefix("shared_experts", prefix),
+                    tp_rank=0,
+                    tp_size=1,
+                )
+
+        if global_server_args_dict["enable_deepep_moe"]:
+            self.num_experts = config.n_routed_experts
+            self.top_k = config.num_experts_per_tok
+            self.renormalize = config.norm_topk_prob
+            self.topk_group = config.topk_group
+            self.num_expert_group = config.n_group
+            self.correction_bias = (
+                self.gate.e_score_correction_bias.data
+                if self.gate.e_score_correction_bias is not None
+                else None
+            )
+
+            self.deepep_dispatcher = DeepEPDispatcher(
+                group=parallel_state.get_tp_group().device_group,
+                router_topk=self.top_k,
+                permute_fusion=True,
+                num_experts=config.n_routed_experts,
+                num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
-                intermediate_size=intermediate_size,
-                hidden_act=config.hidden_act,
-                quant_config=quant_config,
-                reduce_results=False,
-                prefix=add_prefix("shared_experts", prefix),
+                params_dtype=config.torch_dtype,
+                async_finish=True,  # TODO
             )
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+    def forward(
+        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+    ) -> torch.Tensor:
+        if not global_server_args_dict["enable_deepep_moe"]:
+            return self.forward_normal(hidden_states)
+        else:
+            return self.forward_deepep(hidden_states, forward_mode)
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
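The `forward_deepep` path introduced in the next hunk calls `select_experts(..., use_grouped_topk=True, ...)` with the `topk_group`, `num_expert_group`, and `correction_bias` values captured above. As a rough, single-token sketch of group-limited top-k routing in plain PyTorch (a simplification that ignores the correction bias and renormalization; not the sglang kernel):

```python
import torch


def grouped_topk(scores: torch.Tensor, num_expert_group: int, topk_group: int, top_k: int):
    # Score each group by its best expert, keep only the best `topk_group` groups,
    # then pick the global top_k experts from the surviving groups.
    group_scores = scores.view(num_expert_group, -1).max(dim=-1).values
    keep = torch.topk(group_scores, topk_group).indices
    mask = torch.zeros(num_expert_group, dtype=torch.bool)
    mask[keep] = True
    masked = scores.clone().view(num_expert_group, -1)
    masked[~mask] = float("-inf")
    return torch.topk(masked.view(-1), top_k)


scores = torch.softmax(torch.randn(16), dim=-1)  # 16 experts in 4 groups of 4
weights, idx = grouped_topk(scores, num_expert_group=4, topk_group=2, top_k=4)
print(idx.tolist(), weights.tolist())
```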
@@ -206,8 +268,64 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+    ) -> torch.Tensor:
+        shared_output = None
+        topk_idx = torch.full(
+            (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+        )
+        topk_weights = torch.empty(
+            (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+        )
+        if (
+            forward_mode is not None
+            and not forward_mode.is_idle()
+            and hidden_states.shape[0] > 0
+        ):
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            if self.n_shared_experts is not None:
+                shared_output = self.shared_experts(hidden_states)
+            topk_weights, topk_idx = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=True,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                correction_bias=self.correction_bias,
+            )
+        if self.tp_size > 1:
+            recv_hidden_states, reorder_topk_ids, seg_indptr = (
+                self.deepep_dispatcher.dispatch(
+                    hidden_states,
+                    topk_idx,
+                    topk_weights,
+                    self.num_experts,
+                    forward_mode,
+                )
+            )
+        final_hidden_states = (
+            self.experts(
+                hidden_states=recv_hidden_states,
+                reorder_topk_ids=reorder_topk_ids,
+                seg_indptr=seg_indptr,
+                forward_mode=forward_mode,
+            )
+            * self.routed_scaling_factor
+        )
+        if self.tp_size > 1:
+            final_hidden_states = self.deepep_dispatcher.combine(
+                final_hidden_states, forward_mode
+            )
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
 
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return final_hidden_states
 
 
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
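`forward_deepep` above follows a dispatch → expert compute → combine pattern: tokens are permuted so each expert sees a contiguous slice (`reorder_topk_ids`/`seg_indptr`), the expert MLPs run on the permuted activations, and the permutation is undone afterwards. A toy single-process stand-in for that pattern, using top-1 routing (names are illustrative; this is not DeepEP or `DeepEPDispatcher`):

```python
import torch


class ToyDispatcher:
    """Groups tokens by routed expert id, then restores the original token order."""

    def dispatch(self, hidden_states: torch.Tensor, topk_idx: torch.Tensor):
        order = torch.argsort(topk_idx)        # token order sorted by expert id
        self._inverse = torch.argsort(order)   # permutation that undoes the sort
        return hidden_states[order], topk_idx[order]

    def combine(self, expert_out: torch.Tensor) -> torch.Tensor:
        return expert_out[self._inverse]


hidden = torch.randn(6, 4)
experts = torch.tensor([1, 0, 1, 0, 2, 2])     # top-1 expert id per token
dispatcher = ToyDispatcher()
permuted, _ = dispatcher.dispatch(hidden, experts)
restored = dispatcher.combine(permuted * 2.0)  # pretend every expert doubles its tokens
assert torch.allclose(restored, hidden * 2.0)
```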
@@ -537,6 +655,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
+        self.attention_backend = global_server_args_dict["attention_backend"]
         self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
 
     def no_absorb(self, forward_batch: ForwardBatch) -> bool:
@@ -547,15 +666,18 @@ class DeepseekV2AttentionMLA(nn.Module):
                 and forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and forward_batch.extend_prefix_lens.sum() == 0
+                and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )
+        elif self.attention_backend == "fa3":
+            # Flash Attention: Keep absorbing for all extend/decode
+            return False
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             return (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and forward_batch.extend_prefix_lens.sum() == 0
+                and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )
 
     def forward(
@@ -857,6 +979,14 @@ class DeepseekV2DecoderLayer(nn.Module):
         is_nextn: bool = False,
         prefix: str = "",
     ) -> None:
+
+        def is_sparse_layer(l: int):
+            return (
+                config.n_routed_experts is not None
+                and l >= config.first_k_dense_replace
+                and l % config.moe_layer_freq == 0
+            )
+
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
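For reference, the `is_sparse_layer` helper added above can be evaluated standalone. Assuming DeepSeek-V2-style config values (160 routed experts, `first_k_dense_replace=1`, `moe_layer_freq=1`), only layer 0 keeps a dense MLP:

```python
def is_sparse_layer(layer_id: int, n_routed_experts, first_k_dense_replace: int, moe_layer_freq: int) -> bool:
    return (
        n_routed_experts is not None
        and layer_id >= first_k_dense_replace
        and layer_id % moe_layer_freq == 0
    )


print([is_sparse_layer(i, 160, 1, 1) for i in range(4)])  # [False, True, True, True]
```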
@@ -865,6 +995,8 @@ class DeepseekV2DecoderLayer(nn.Module):
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
         self.dp_size = get_attention_dp_size()
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
 
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
@@ -907,16 +1039,13 @@ class DeepseekV2DecoderLayer(nn.Module):
                 prefix=add_prefix("self_attn", prefix),
             )
 
-        if is_nextn or (
-            config.n_routed_experts is not None
-            and layer_id >= config.first_k_dense_replace
-            and layer_id % config.moe_layer_freq == 0
-        ):
+        if is_nextn or is_sparse_layer(layer_id):
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
+            self.is_sparse = True
         else:
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
@@ -925,6 +1054,14 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
+            self.is_sparse = False
+
+        self.input_is_scattered = (
+            is_sparse_layer(layer_id - 1)
+            and global_server_args_dict["enable_deepep_moe"]
+        )
+        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
+
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -937,12 +1074,82 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        if residual is None:
+        if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
+            return self.forward_deepep(
+                positions, hidden_states, forward_batch, residual
+            )
+        else:
+            return self.forward_normal(
+                positions, hidden_states, forward_batch, residual
+            )
+
+    def forward_normal(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+
+        if hidden_states.shape[0] == 0:
             residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+            # Self Attention
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+        if self.attn_tp_size != 1 and self.input_is_scattered:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+            residual, local_residual = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                residual,
+            )
+            tp_all_gather(
+                list(residual.tensor_split(self.attn_tp_size)), local_residual
+            )
+
+        # Gather
+        if get_tensor_model_parallel_world_size() > 1:
+            # all gather and all reduce
+            if self.dp_size != 1:
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                hidden_states = self.post_attention_layernorm(hidden_states)
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
 
+        # Fully Connected
+        hidden_states = self.mlp(hidden_states)
+
+        # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
         # Scatter
         if self.dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
@@ -953,6 +1160,34 @@ class DeepseekV2DecoderLayer(nn.Module):
             )
             dp_scatter(hidden_states, global_hidden_states, forward_batch)
 
+        return hidden_states, residual
+
+    def forward_deepep(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+
+        if hidden_states.shape[0] == 0:
+            residual = hidden_states
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        if self.attn_tp_size != 1 and self.input_is_scattered:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
@@ -960,24 +1195,47 @@ class DeepseekV2DecoderLayer(nn.Module):
             forward_batch=forward_batch,
         )
 
-        # Gather
-        if get_tensor_model_parallel_world_size() > 1:
-            # all gather and all reduce
-            if self.dp_size != 1:
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather(
-                    hidden_states, local_hidden_states, forward_batch, self.layer_id
-                )
+        if self.attn_tp_size != 1:
+            if self.input_is_scattered:
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                tp_reduce_scatter(hidden_states, tensor_list)
+                if hidden_states.shape[0] != 0:
+                    hidden_states, residual = self.post_attention_layernorm(
+                        hidden_states, residual
+                    )
             else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                tp_reduce_scatter(hidden_states, tensor_list)
+                residual = hidden_states
+                if hidden_states.shape[0] != 0:
+                    hidden_states = self.post_attention_layernorm(hidden_states)
+        else:
+            if hidden_states.shape[0] != 0:
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
 
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        if self.is_last_layer and self.attn_tp_size != 1:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+            residual, local_residual = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                residual,
+            )
+            tp_all_gather(
+                list(residual.tensor_split(self.attn_tp_size)), local_residual
+            )
 
-        # Fully Connected
-        hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
 
 
@@ -1020,23 +1278,17 @@ class DeepseekV2Model(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
 
-        if self.dp_size != 1:
-            input_ids, local_input_ids = (
-                torch.empty(
-                    (forward_batch.gathered_buffer.shape[0],),
-                    dtype=input_ids.dtype,
-                    device=input_ids.device,
-                ),
-                input_ids,
-            )
-            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
 
-        hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
+            expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
@@ -1075,17 +1327,10 @@ class DeepseekV2ForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch)
 
-        if self.dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
 
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
@@ -1100,7 +1345,11 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )
         expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
@@ -1174,14 +1423,21 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|
1174
1423
|
self_attn = self.model.layers[layer_id].self_attn
|
1175
1424
|
if hasattr(self_attn.kv_b_proj, "qweight"):
|
1176
1425
|
# AWQ compatible
|
1177
|
-
|
1178
|
-
|
1179
|
-
|
1180
|
-
|
1181
|
-
|
1182
|
-
|
1183
|
-
|
1184
|
-
|
1426
|
+
if _is_cuda:
|
1427
|
+
w = awq_dequantize(
|
1428
|
+
self_attn.kv_b_proj.qweight,
|
1429
|
+
self_attn.kv_b_proj.scales,
|
1430
|
+
self_attn.kv_b_proj.qzeros,
|
1431
|
+
).T
|
1432
|
+
else:
|
1433
|
+
w = ops.awq_dequantize(
|
1434
|
+
self_attn.kv_b_proj.qweight,
|
1435
|
+
self_attn.kv_b_proj.scales,
|
1436
|
+
self_attn.kv_b_proj.qzeros,
|
1437
|
+
0,
|
1438
|
+
0,
|
1439
|
+
0,
|
1440
|
+
).T
|
1185
1441
|
else:
|
1186
1442
|
w = self_attn.kv_b_proj.weight
|
1187
1443
|
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
|