sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -23,18 +23,23 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm import _custom_ops as ops
 
 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-
+    parallel_state,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
     decode_attention_fwd_grouped_rope,
 )
+from sglang.srt.layers.dp_attention import (
+    dp_gather_partial,
+    dp_scatter,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -43,8 +48,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
@@ -60,15 +67,21 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
+from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
 
 _is_hip = is_hip()
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import awq_dequantize, bmm_fp8
+else:
+    from vllm import _custom_ops as ops
 
-
-from sgl_kernel import bmm_fp8
+expert_distribution_recorder = ExpertDistributionRecorder()
 
 
 class DeepseekV2MLP(nn.Module):
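Note: the import block above now selects kernel providers by platform: CUDA builds import awq_dequantize and bmm_fp8 from sgl_kernel, while other builds fall back to vLLM's _custom_ops (previously bmm_fp8 was imported unconditionally). A self-contained sketch of the same guard-and-fallback pattern; the module name fast_kernels and the helper below are hypothetical stand-ins, not sglang or vLLM APIs:

import torch

try:
    # Hypothetical CUDA-only extension; on machines without it the import fails
    # and the plain PyTorch fallback below is used instead.
    from fast_kernels import fused_rmsnorm as _fused_rmsnorm
    _HAS_FAST_KERNELS = True
except ImportError:
    _HAS_FAST_KERNELS = False

def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    if _HAS_FAST_KERNELS:
        return _fused_rmsnorm(x, weight, eps)
    # Reference fallback in plain PyTorch.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

print(rmsnorm(torch.randn(2, 8), torch.ones(8)).shape)  # torch.Size([2, 8])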
@@ -80,6 +93,8 @@ class DeepseekV2MLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -88,6 +103,8 @@ class DeepseekV2MLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -96,6 +113,8 @@ class DeepseekV2MLP(nn.Module):
             quant_config=quant_config,
             reduce_results=reduce_results,
             prefix=add_prefix("down_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -160,7 +179,11 @@ class DeepseekV2MoE(nn.Module):
 
         self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
 
-        MoEImpl =
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )
         self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
@@ -177,18 +200,60 @@ class DeepseekV2MoE(nn.Module):
 
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-
+            # disable tp for shared experts when enable deepep moe
+            if not global_server_args_dict["enable_deepep_moe"]:
+                self.shared_experts = DeepseekV2MLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=intermediate_size,
+                    hidden_act=config.hidden_act,
+                    quant_config=quant_config,
+                    reduce_results=False,
+                    prefix=add_prefix("shared_experts", prefix),
+                )
+            else:
+                self.shared_experts = DeepseekV2MLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=intermediate_size,
+                    hidden_act=config.hidden_act,
+                    quant_config=quant_config,
+                    reduce_results=False,
+                    prefix=add_prefix("shared_experts", prefix),
+                    tp_rank=0,
+                    tp_size=1,
+                )
+
+        if global_server_args_dict["enable_deepep_moe"]:
+            self.num_experts = config.n_routed_experts
+            self.top_k = config.num_experts_per_tok
+            self.renormalize = config.norm_topk_prob
+            self.topk_group = config.topk_group
+            self.num_expert_group = config.n_group
+            self.correction_bias = (
+                self.gate.e_score_correction_bias.data
+                if self.gate.e_score_correction_bias is not None
+                else None
+            )
+
+            self.deepep_dispatcher = DeepEPDispatcher(
+                group=parallel_state.get_tp_group().device_group,
+                router_topk=self.top_k,
+                permute_fusion=True,
+                num_experts=config.n_routed_experts,
+                num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
-
-
-                quant_config=quant_config,
-                reduce_results=False,
-                prefix=add_prefix("shared_experts", prefix),
+                params_dtype=config.torch_dtype,
+                async_finish=True,  # TODO
             )
 
-    def forward(
-
-
+    def forward(
+        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+    ) -> torch.Tensor:
+        if not global_server_args_dict["enable_deepep_moe"]:
+            return self.forward_normal(hidden_states)
+        else:
+            return self.forward_deepep(hidden_states, forward_mode)
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
@@ -201,8 +266,60 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
 
-
+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+    ) -> torch.Tensor:
+        shared_output = None
+        topk_idx = torch.full(
+            (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+        )
+        topk_weights = torch.empty(
+            (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+        )
+        if forward_mode is not None and not forward_mode.is_idle():
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            if self.n_shared_experts is not None:
+                shared_output = self.shared_experts(hidden_states)
+            topk_weights, topk_idx = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=True,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                correction_bias=self.correction_bias,
+            )
+        if self.tp_size > 1:
+            recv_hidden_states, reorder_topk_ids, seg_indptr = (
+                self.deepep_dispatcher.dispatch(
+                    hidden_states,
+                    topk_idx,
+                    topk_weights,
+                    self.num_experts,
+                    forward_mode,
+                )
+            )
+        final_hidden_states = (
+            self.experts(
+                hidden_states=recv_hidden_states,
+                reorder_topk_ids=reorder_topk_ids,
+                seg_indptr=seg_indptr,
+                forward_mode=forward_mode,
+            )
+            * self.routed_scaling_factor
+        )
+        if self.tp_size > 1:
+            final_hidden_states = self.deepep_dispatcher.combine(
+                final_hidden_states, forward_mode
+            )
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+
+        return final_hidden_states
 
 
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
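Note: DeepseekV2MoE.forward is now a thin dispatcher: it calls forward_normal for the FusedMoE/EPMoE paths and forward_deepep when the enable_deepep_moe server argument is set. The DeepEP path first runs grouped top-k routing via select_experts, dispatches tokens to expert-parallel ranks through DeepEPDispatcher, and combines the partial results afterwards. The self-contained sketch below illustrates only the top-k routing step in plain PyTorch (no expert grouping, bias correction, or communication); toy_select_experts is an illustrative stand-in, not sglang's API:

import torch

def toy_select_experts(router_logits: torch.Tensor, top_k: int, renormalize: bool = True):
    """Pick top_k experts per token and return (weights, indices).

    router_logits: [num_tokens, num_experts]
    """
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_idx = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        # Make the selected experts' weights sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_idx

# Example: 4 tokens routed over 8 experts, 2 experts per token.
logits = torch.randn(4, 8)
weights, idx = toy_select_experts(logits, top_k=2)
print(idx.shape, weights.sum(dim=-1))  # torch.Size([4, 2]), all ones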
@@ -230,6 +347,7 @@ class DeepseekV2Attention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        reduce_results: bool = True,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -241,10 +359,14 @@ class DeepseekV2Attention(nn.Module):
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
+
+        self.dp_size = get_attention_dp_size()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.num_heads = num_heads
-
-
-        self.num_local_heads = num_heads // tp_size
+        assert num_heads % attn_tp_size == 0
+        self.num_local_heads = num_heads // attn_tp_size
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
@@ -272,6 +394,8 @@ class DeepseekV2Attention(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("q_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
         )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -296,6 +420,9 @@ class DeepseekV2Attention(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("o_proj", prefix),
+            reduce_results=reduce_results,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope_wrapper(
@@ -330,6 +457,12 @@ class DeepseekV2Attention(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results
+            ), "short-circuiting allreduce will lead to hangs"
+            return hidden_states
+
         if self.q_lora_rank is not None:
             q = self.q_a_proj(hidden_states)[0]
             q = self.q_a_layernorm(q)
@@ -385,8 +518,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
-
-
+        reduce_results: bool = True,
+        layer_id: int = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -398,96 +531,66 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
+        self.dp_size = get_attention_dp_size()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.num_heads = num_heads
-
-
-        self.num_local_heads = num_heads if use_dp else num_heads // tp_size
+        assert num_heads % attn_tp_size == 0
+        self.num_local_heads = num_heads // attn_tp_size
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-
-
-
-            self.q_a_proj = ReplicatedLinear(
-                self.hidden_size,
-                self.q_lora_rank,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_a_proj", prefix),
-            )
-            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-            self.q_b_proj = ReplicatedLinear(
-                q_lora_rank,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_b_proj", prefix),
-            )
-        else:
-            self.q_proj = ReplicatedLinear(
-                self.hidden_size,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_proj", prefix),
-            )
-        self.kv_b_proj = ReplicatedLinear(
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_b_proj", prefix),
-        )
-        # O projection.
-        self.o_proj = ReplicatedLinear(
-            self.num_heads * self.v_head_dim,
+        # For tensor parallel attention
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
                 self.hidden_size,
+                self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("
+                prefix=add_prefix("q_a_proj", prefix),
             )
-
-
-
-            self.
-                self.hidden_size,
-                self.q_lora_rank,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_a_proj", prefix),
-            )
-            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-            self.q_b_proj = ColumnParallelLinear(
-                q_lora_rank,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_b_proj", prefix),
-            )
-        else:
-            self.q_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_proj", prefix),
-            )
-        self.kv_b_proj = ColumnParallelLinear(
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("
+                prefix=add_prefix("q_b_proj", prefix),
+                tp_rank=attn_tp_rank,
+                tp_size=attn_tp_size,
             )
-
-        self.
-            self.num_heads * self.v_head_dim,
+        else:
+            self.q_proj = ColumnParallelLinear(
                 self.hidden_size,
+                self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
-            prefix=add_prefix("
+                prefix=add_prefix("q_proj", prefix),
+                tp_rank=attn_tp_rank,
+                tp_size=attn_tp_size,
             )
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("kv_b_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+            prefix=add_prefix("o_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+        )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
             self.hidden_size,
@@ -542,38 +645,49 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_vc = None
         self.w_scale = None
 
+        self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
+        self.flashinfer_mla_disable_ragged = global_server_args_dict[
+            "flashinfer_mla_disable_ragged"
+        ]
+        self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+
+    def no_absorb(self, forward_batch: ForwardBatch) -> bool:
+        if self.enable_flashinfer_mla:
+            # Flashinfer MLA: Do not absorb when enabling ragged prefill
+            return (
+                not self.flashinfer_mla_disable_ragged
+                and forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and sum(forward_batch.extend_prefix_lens_cpu) == 0
+            )
+        else:
+            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
+            return (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and sum(forward_batch.extend_prefix_lens_cpu) == 0
+            )
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results
+            ), "short-circuiting allreduce will lead to hangs"
+            return hidden_states
 
-
-            if global_server_args_dict["enable_flashinfer_mla"]:
-                # Flashinfer MLA: Do not absorb when enabling ragged prefill
-                return (
-                    not global_server_args_dict["flashinfer_mla_disable_ragged"]
-                    and forward_batch.forward_mode.is_extend()
-                    and not forward_batch.forward_mode.is_target_verify()
-                    and not forward_batch.forward_mode.is_draft_extend()
-                    and forward_batch.extend_prefix_lens.sum() == 0
-                )
-            else:
-                # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-                return (
-                    forward_batch.forward_mode.is_extend()
-                    and not forward_batch.forward_mode.is_target_verify()
-                    and not forward_batch.forward_mode.is_draft_extend()
-                    and forward_batch.extend_prefix_lens.sum() == 0
-                )
-
-        if no_absorb():
+        if self.no_absorb(forward_batch):
             return self.forward_normal(positions, hidden_states, forward_batch)
         else:
             if _is_hip:
                 if (
-
+                    self.rocm_fused_decode_mla
                     and forward_batch.forward_mode.is_decode()
                 ):
                     return self.forward_absorb_fused_mla_rope(
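Note: the absorb-or-not decision that previously lived as a nested closure inside forward is now the no_absorb method shown above. Per the in-code comments, the unabsorbed MLA kernel is used only for a fresh prefill (extend mode with no cached prefix); everything else goes through weight absorption. A standalone sketch of that predicate with plain values; the ToyBatch dataclass is a stand-in for the fields read from sglang's ForwardBatch, not its real definition:

from dataclasses import dataclass
from typing import List

@dataclass
class ToyBatch:
    # Illustrative stand-in for the fields no_absorb() reads from ForwardBatch.
    is_extend: bool
    is_target_verify: bool
    is_draft_extend: bool
    extend_prefix_lens_cpu: List[int]

def toy_no_absorb(batch: ToyBatch) -> bool:
    """Use the unabsorbed MLA path only for a fresh prefill (no cached prefix)."""
    return (
        batch.is_extend
        and not batch.is_target_verify
        and not batch.is_draft_extend
        and sum(batch.extend_prefix_lens_cpu) == 0
    )

print(toy_no_absorb(ToyBatch(True, False, False, [0, 0])))  # True  -> normal prefill kernel
print(toy_no_absorb(ToyBatch(True, False, False, [5, 0])))  # False -> weight absorption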
@@ -845,34 +959,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
 
-def all_gather(
-    input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
-):
-    all_lens = forward_batch.global_num_tokens_cpu
-    max_len = max(forward_batch.global_num_tokens_cpu)
-
-    if world_size == 1:
-        return input_tensor, 0, all_lens[0]
-
-    padded_tensor = torch.nn.functional.pad(
-        input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
-    )
-
-    group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)
-
-    gathered_tensors = torch.concat(
-        [
-            forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
-            for i in range(world_size)
-        ]
-    )
-
-    start_index = 0 if rank == 0 else sum(all_lens[:rank])
-    end_index = start_index + all_lens[rank]
-
-    return gathered_tensors, start_index, end_index
-
-
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
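Note: the module-level all_gather helper removed above padded each data-parallel rank's tokens to the longest batch, gathered them into forward_batch.gathered_buffer, and sliced each rank's own span back out; the post2 decoder layer delegates that bookkeeping to dp_gather_partial / dp_scatter from sglang.srt.layers.dp_attention. A single-process sketch of the pad-gather-slice arithmetic the old helper performed (no collectives, purely illustrative; toy_pad_gather_slice is not an sglang API):

import torch
import torch.nn.functional as F

def toy_pad_gather_slice(per_rank_tokens, rank):
    """per_rank_tokens: list of [num_tokens_i, hidden] tensors, one per DP rank."""
    all_lens = [t.shape[0] for t in per_rank_tokens]
    max_len = max(all_lens)
    # Pad every rank's batch to max_len, as if written into one shared gathered buffer.
    gathered_buffer = torch.cat(
        [F.pad(t, (0, 0, 0, max_len - t.shape[0])) for t in per_rank_tokens]
    )
    # Drop the padding rows so token order matches the concatenated real batches.
    gathered = torch.cat(
        [gathered_buffer[i * max_len : i * max_len + all_lens[i]] for i in range(len(all_lens))]
    )
    start = sum(all_lens[:rank])
    return gathered, start, start + all_lens[rank]

batches = [torch.randn(3, 4), torch.randn(1, 4), torch.randn(2, 4)]
gathered, lo, hi = toy_pad_gather_slice(batches, rank=1)
assert gathered.shape[0] == 6 and torch.equal(gathered[lo:hi], batches[1])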
@@ -888,14 +974,10 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.enable_dp_attention =
-
-
-
-        if self.enable_dp_attention:
-            self.tp_rank = get_tensor_model_parallel_rank()
-            self.tp_size = get_tensor_model_parallel_world_size()
-            self.tp_group = get_tp_group()
+        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.layer_id = layer_id
+        self.dp_size = get_attention_dp_size()
+
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
@@ -913,7 +995,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
-
+                reduce_results=False,
                 prefix=add_prefix("self_attn", prefix),
             )
         else:
@@ -933,8 +1015,10 @@ class DeepseekV2DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                reduce_results=False,
                 prefix=add_prefix("self_attn", prefix),
             )
+
         if is_nextn or (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
@@ -965,32 +1049,67 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-
-
+        if hidden_states.shape[0] == 0:
+            residual = hidden_states
+        else:
             if residual is None:
                 residual = hidden_states
                 hidden_states = self.input_layernorm(hidden_states)
             else:
                 hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
+            # Self Attention
             hidden_states = self.self_attn(
                 positions=positions,
                 hidden_states=hidden_states,
                 forward_batch=forward_batch,
             )
+
+        # Gather
+        if get_tensor_model_parallel_world_size() > 1:
+            # all gather and all reduce
+            if self.dp_size != 1:
+                if global_server_args_dict["enable_deepep_moe"] and isinstance(
+                    self.mlp, DeepseekV2MoE
+                ):
+                    if hidden_states.shape[0] != 0:
+                        hidden_states, residual = self.post_attention_layernorm(
+                            hidden_states, residual
+                        )
+                    hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+                    return hidden_states, residual
+                else:
+                    if get_attention_tp_rank() == 0:
+                        hidden_states += residual
+                    hidden_states, local_hidden_states = (
+                        forward_batch.gathered_buffer,
+                        hidden_states,
+                    )
+                    dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                    dp_scatter(residual, hidden_states, forward_batch)
+                    hidden_states = self.post_attention_layernorm(hidden_states)
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+        else:
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual
             )
 
         # Fully Connected
-
-
-
+        hidden_states = self.mlp(hidden_states)
+
+        # Scatter
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
             )
-            hidden_states
-            hidden_states = hidden_states[start_idx:end_idx]
-        else:
-            hidden_states = self.mlp(hidden_states)
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
 
         return hidden_states, residual
 
@@ -1027,15 +1146,24 @@ class DeepseekV2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+        self.dp_size = get_attention_dp_size()
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-
+
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
         residual = None
         for i in range(len(self.layers)):
+            expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
@@ -1059,22 +1187,14 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
-
-
-
-
-
-
-
-
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
-        self.logits_processor = LogitsProcessor(config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.dp_size = get_attention_dp_size()
 
     @torch.no_grad()
     def forward(
@@ -1082,8 +1202,11 @@ class DeepseekV2ForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-
+
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
         )
@@ -1097,7 +1220,11 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl =
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )
         expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
@@ -1171,14 +1298,21 @@ class DeepseekV2ForCausalLM(nn.Module):
                 self_attn = self.model.layers[layer_id].self_attn
                 if hasattr(self_attn.kv_b_proj, "qweight"):
                     # AWQ compatible
-
-
-
-
-
-
-
-
+                    if _is_cuda:
+                        w = awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                        ).T
+                    else:
+                        w = ops.awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                            0,
+                            0,
+                            0,
+                        ).T
                 else:
                     w = self_attn.kv_b_proj.weight
                 # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
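Note: weight loading now dequantizes an AWQ-packed kv_b_proj through sgl_kernel.awq_dequantize on CUDA and falls back to vLLM's _custom_ops elsewhere, transposing the result before the per-tensor requantization mentioned in the NOTE above. A minimal sketch of the dequantize-then-transpose step with a pure-PyTorch stand-in; toy_awq_dequantize is illustrative only, not an sglang or vLLM API, and it skips the real kernels' 4-bit unpacking and group-wise scales:

import torch

def toy_awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor) -> torch.Tensor:
    """Per-column affine dequantization: (code - zero_point) * scale."""
    return (qweight.float() - qzeros.float()) * scales.float()

qweight = torch.randint(0, 16, (128, 64))   # [in_features, out_features] integer codes
scales = torch.full((64,), 0.05)             # per-output-column scale
qzeros = torch.full((64,), 8.0)              # per-output-column zero point
w = toy_awq_dequantize(qweight, scales, qzeros).T  # loader keeps the transposed view
print(w.shape)  # torch.Size([64, 128])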
|