sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
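The largest part of the diff below moves `DeepseekV2DecoderLayer.forward` onto the new operation-list machinery (`sglang/srt/operations.py` and `sglang/srt/operations_strategy.py` in the file list above): each layer now exposes small `op_*` methods that read and write a shared per-layer state, and `execute_operations` runs them in the order chosen by `compute_layer_operations`. The driver itself is not part of this file's diff; the sketch below is only an illustration inferred from how the `op_*` methods use `state` in the hunks that follow (`state.pop(...)`, `state.update(...)`, attribute writes, `state.clear(expect_keys=...)`) — the actual sglang implementation may differ.

```python
# Hypothetical sketch, NOT sglang's real sglang/srt/operations.py.
# It only mirrors the state interface visible in the deepseek_v2.py diff below.
class _State(dict):
    """Dict with attribute-style access so ops can write `state.foo = ...`."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        self[name] = value

    def clear(self, expect_keys=()):
        # Fail loudly if an op forgot to consume an intermediate tensor.
        leftover = set(self) - set(expect_keys)
        assert not leftover, f"unconsumed state keys: {leftover}"
        super().clear()


def execute_operations_sketch(inputs, operations):
    """Run a layer's op_* callables in order against a shared state.

    Assumption: the first op takes the layer inputs as keyword arguments
    (matching op_comm_prepare_attn in the diff); the remaining ops take only
    `state`; the last op returns the layer output (hidden_states, residual).
    """
    state = _State()
    first, *rest = operations
    first(state, **inputs)
    output = None
    for op in rest:
        result = op(state)
        if result is not None:
            output = result
    return output
```

Splitting the MoE path into paired `dispatch_a`/`dispatch_b` and `combine_a`/`combine_b` ops (visible in the DeepseekV2MoE hunk) presumably lets the strategy interleave other work between issuing and completing the DeepEP all-to-all, which the old monolithic forward methods could not express.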
```diff
@@ -18,8 +18,7 @@
 
 import logging
 import os
-from …
-from enum import Enum, IntEnum, auto
+from enum import IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
```
```diff
@@ -29,20 +28,20 @@ from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.communicator import (
+    LayerCommunicator,
+    LayerScatterModes,
+    enable_moe_dense_fully_dp,
+)
 from sglang.srt.layers.dp_attention import (
-    dp_gather_partial,
-    dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    …
-    tp_reduce_scatter,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
```
```diff
@@ -52,9 +51,8 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import …
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
```
```diff
@@ -72,15 +70,21 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope …
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import …
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch …
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.operations import execute_operations
+from sglang.srt.operations_strategy import compute_layer_operations
 from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
```
```diff
@@ -109,8 +113,6 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )
 
-expert_distribution_recorder = ExpertDistributionRecorder()
-
 logger = logging.getLogger(__name__)
 
 
```
```diff
@@ -125,6 +127,9 @@ class AttnForwardMethod(IntEnum):
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
 
+    # Use MLA but with fused RoPE
+    MLA_FUSED_ROPE = auto()
+
 
 class DeepseekV2MLP(nn.Module):
     def __init__(
```
```diff
@@ -139,6 +144,8 @@ class DeepseekV2MLP(nn.Module):
         tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
+        self.tp_size = tp_size
+
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
```
```diff
@@ -165,7 +172,10 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x, …
+    def forward(self, x, forward_batch=None):
+        if (self.tp_size == 1) and x.shape[0] == 0:
+            return x
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
```
```diff
@@ -194,11 +204,20 @@ class MoEGate(nn.Module):
         return logits
 
 
+def is_non_idle_and_non_empty(forward_mode, hidden_states):
+    return (
+        (forward_mode is not None)
+        and not forward_mode.is_idle()
+        and hidden_states.shape[0] > 0
+    )
+
+
 class DeepseekV2MoE(nn.Module):
 
     def __init__(
         self,
         config: PretrainedConfig,
+        layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
```
```diff
@@ -207,6 +226,7 @@ class DeepseekV2MoE(nn.Module):
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+        self.layer_id = layer_id
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
```
```diff
@@ -222,17 +242,14 @@ class DeepseekV2MoE(nn.Module):
 
         self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
 
-        …
-        )
-
-        self.experts = MoEImpl(
-            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
+        self.experts = get_moe_impl_class()(
+            num_experts=config.n_routed_experts
+            + self.n_share_experts_fusion
+            + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
+            layer_id=self.layer_id,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
```
```diff
@@ -251,32 +268,29 @@ class DeepseekV2MoE(nn.Module):
         if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
-            …
-                reduce_results=False,
-                prefix=add_prefix("shared_experts", prefix),
-                tp_rank=0,
-                tp_size=1,
-            )
+            self.shared_experts = DeepseekV2MLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if global_server_args_dict["enable_deepep_moe"]
+                    else {}
+                ),
+            )
+
+        self.top_k = config.num_experts_per_tok
 
         if global_server_args_dict["enable_deepep_moe"]:
             # TODO: we will support tp < ep in the future
             self.ep_size = get_tensor_model_parallel_world_size()
-            self.num_experts = …
+            self.num_experts = (
+                config.n_routed_experts
+                + global_server_args_dict["ep_num_redundant_experts"]
+            )
             self.renormalize = config.norm_topk_prob
             self.topk_group = config.topk_group
             self.num_expert_group = config.n_group
```
```diff
@@ -290,7 +304,7 @@ class DeepseekV2MoE(nn.Module):
                 group=parallel_state.get_tp_group().device_group,
                 router_topk=self.top_k,
                 permute_fusion=True,
-                num_experts=…
+                num_experts=self.num_experts,
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
```
```diff
@@ -299,105 +313,137 @@ class DeepseekV2MoE(nn.Module):
                 return_recv_hook=True,
             )
 
-    …
-        if not global_server_args_dict["enable_deepep_moe"]:
-            return self.forward_normal(hidden_states)
-        else:
-            return self.forward_deepep(hidden_states, forward_mode)
+    @property
+    def _enable_deepep_moe(self):
+        return global_server_args_dict["enable_deepep_moe"]
 
-    def …
-        …
-            * self.routed_scaling_factor
-        )
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-        return final_hidden_states
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
 
-    def …
-        shared_output = None
-        if (
-            forward_mode is not None
-            and not forward_mode.is_idle()
-            and hidden_states.shape[0] > 0
+    def op_gate(self, state):
+        if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, state.hidden_states_mlp_input
         ):
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(…
-            shared_output = self._forward_shared_experts(hidden_states)
-            topk_weights, topk_idx = select_experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=True,
-                renormalize=self.renormalize,
-                topk_group=self.topk_group,
-                num_expert_group=self.num_expert_group,
-                correction_bias=self.correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-            )
+            state.router_logits = self.gate(state.hidden_states_mlp_input)
         else:
-            …
+            state.router_logits = None
+
+    def op_shared_experts(self, state):
+        if (self.n_share_experts_fusion == 0) and (
+            (not self._enable_deepep_moe)
+            or is_non_idle_and_non_empty(
+                state.forward_batch.forward_mode, state.hidden_states_mlp_input
             )
-        …
+        ):
+            state.shared_output = self.shared_experts(state.hidden_states_mlp_input)
+        else:
+            state.shared_output = None
+
+    def op_select_experts(self, state):
+        router_logits = state.router_logits
+        hidden_states = state.hidden_states_mlp_input
+
+        if self._enable_deepep_moe:
+            if router_logits is not None:
+                state.topk_weights_local, state.topk_idx_local = select_experts(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    top_k=self.top_k,
+                    use_grouped_topk=True,
+                    renormalize=self.renormalize,
+                    topk_group=self.topk_group,
+                    num_expert_group=self.num_expert_group,
+                    correction_bias=self.correction_bias,
+                    routed_scaling_factor=self.routed_scaling_factor,
+                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                        layer_id=self.layer_id,
+                    ),
+                )
+            else:
+                state.topk_idx_local = torch.full(
+                    (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+                )
+                state.topk_weights_local = torch.empty(
+                    (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+                )
+
+    def op_dispatch_a(self, state):
+        if self._enable_deepep_moe and (self.ep_size > 1):
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            self.deepep_dispatcher.dispatch_a(
+                hidden_states=state.pop("hidden_states_mlp_input"),
+                topk_idx=state.pop("topk_idx_local"),
+                topk_weights=state.pop("topk_weights_local"),
+                forward_mode=state.forward_batch.forward_mode,
+            )
+
+    def op_dispatch_b(self, state):
+        if self._enable_deepep_moe and (self.ep_size > 1):
             (
-                …
-                reorder_topk_ids,
-                num_recv_tokens_per_expert,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            ) = self.deepep_dispatcher.…
-            …
+                state.hidden_states_experts_input,
+                state.topk_idx_dispatched,
+                state.topk_weights_dispatched,
+                state.reorder_topk_ids,
+                state.num_recv_tokens_per_expert,
+                state.seg_indptr,
+                state.masked_m,
+                state.expected_m,
+            ) = self.deepep_dispatcher.dispatch_b()
+
+    def op_experts(self, state):
+        if self._enable_deepep_moe:
+            state.pop("router_logits")
+            state.hidden_states_experts_output = self.experts(
+                hidden_states=state.pop("hidden_states_experts_input"),
+                topk_idx=state.topk_idx_dispatched,
+                topk_weights=state.topk_weights_dispatched,
+                reorder_topk_ids=state.pop("reorder_topk_ids"),
+                seg_indptr=state.pop("seg_indptr"),
+                masked_m=state.pop("masked_m"),
+                expected_m=state.pop("expected_m"),
+                num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
+                forward_mode=state.forward_batch.forward_mode,
             )
-            …
-                reorder_topk_ids=reorder_topk_ids,
-                seg_indptr=seg_indptr,
-                masked_m=masked_m,
-                expected_m=expected_m,
-                num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-                forward_mode=forward_mode,
-            )
-        if self.ep_size > 1:
-            final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_mode,
+        else:
+            state.hidden_states_experts_output = self.experts(
+                hidden_states=state.pop("hidden_states_mlp_input"),
+                router_logits=state.pop("router_logits"),
             )
+
+    def op_combine_a(self, state):
+        if self._enable_deepep_moe and (self.ep_size > 1):
+            self.deepep_dispatcher.combine_a(
+                state.pop("hidden_states_experts_output"),
+                topk_idx=state.pop("topk_idx_dispatched"),
+                topk_weights=state.pop("topk_weights_dispatched"),
+                forward_mode=state.forward_batch.forward_mode,
+            )
+
+    def op_combine_b(self, state):
+        if self._enable_deepep_moe and (self.ep_size > 1):
+            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b()
+
+    def op_output(self, state):
+        final_hidden_states = (
+            state.pop("hidden_states_after_combine")
+            if self._enable_deepep_moe
+            else state.pop("hidden_states_experts_output")
+        )
+
         final_hidden_states *= self.routed_scaling_factor
 
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + …
+        if (s := state.pop("shared_output")) is not None:
+            final_hidden_states = final_hidden_states + s
 
-        …
+        if (not self._enable_deepep_moe) and (self.tp_size > 1):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
-        …
-        if self.n_share_experts_fusion == 0:
-            return self.shared_experts(hidden_states)
-        else:
-            return None
+        state.hidden_states_mlp_output = final_hidden_states
 
 
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
```
```diff
@@ -438,7 +484,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
 
```
```diff
@@ -579,6 +624,18 @@ class DeepseekV2AttentionMLA(nn.Module):
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
+        def _dispatch_mla_subtype():
+            if _is_hip:
+                if (
+                    self.rocm_fused_decode_mla
+                    and forward_batch.forward_mode.is_decode()
+                ):
+                    return AttnForwardMethod.MLA_FUSED_ROPE
+                else:
+                    return AttnForwardMethod.MLA
+            else:
+                return AttnForwardMethod.MLA
+
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
```
```diff
@@ -590,7 +647,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA
             else:
-                return …
+                return _dispatch_mla_subtype()
         elif self.attention_backend == "fa3":
             # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
             if forward_batch.extend_prefix_lens_cpu is not None:
```
```diff
@@ -607,7 +664,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
-                return …
+                return _dispatch_mla_subtype()
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             if (
```
```diff
@@ -618,7 +675,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA
             else:
-                return …
+                return _dispatch_mla_subtype()
 
     def forward(
         self,
```
```diff
@@ -641,23 +698,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_chunked_kv(
                 positions, hidden_states, forward_batch
             )
+        elif attn_forward_method == AttnForwardMethod.MLA:
+            return self.forward_absorb(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
+            return self.forward_absorb_fused_mla_rope(
+                positions, hidden_states, forward_batch
+            )
         else:
-            …
-            if (
-                self.rocm_fused_decode_mla
-                and forward_batch.forward_mode.is_decode()
-            ):
-                return self.forward_absorb_fused_mla_rope(
-                    positions, hidden_states, forward_batch
-                )
-            else:
-                return self.forward_absorb(
-                    positions, hidden_states, forward_batch, zero_allocator
-                )
-        else:
-            return self.forward_absorb(
-                positions, hidden_states, forward_batch, zero_allocator
-            )
+            raise NotImplementedError
 
     def forward_normal(
         self,
```
```diff
@@ -711,6 +761,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         if self.q_lora_rank is not None:
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
```
```diff
@@ -718,7 +770,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             k_nope = latent_cache[..., : self.kv_lora_rank]
 
             # overlap qk norm
-            if self.alt_stream is not None and …
+            if self.alt_stream is not None and get_is_capture_mode():
                 current_stream = torch.cuda.current_stream()
                 self.alt_stream.wait_stream(current_stream)
                 q = self.q_a_layernorm(q)
```
```diff
@@ -1102,19 +1154,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
 
-class _FFNInputMode(Enum):
-    # The MLP sublayer requires 1/tp_size tokens as input
-    SCATTERED = auto()
-    # The MLP sublayer requires all tokens as input
-    FULL = auto()
-
-
-@dataclass
-class _DecoderLayerInfo:
-    is_sparse: bool
-    ffn_input_mode: _FFNInputMode
-
-
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
```
```diff
@@ -1128,14 +1167,12 @@ class DeepseekV2DecoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.config = config
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
-        self.dp_size = get_attention_dp_size()
-        self.attn_tp_size = get_attention_tp_size()
-        self.attn_tp_rank = get_attention_tp_rank()
         self.self_attn = DeepseekV2AttentionMLA(
             config=config,
             hidden_size=self.hidden_size,
```
```diff
@@ -1157,19 +1194,25 @@ class DeepseekV2DecoderLayer(nn.Module):
             alt_stream=alt_stream,
         )
 
-        self.…
+        self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
+        is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
         )
 
-        if self.…
+        if self.is_layer_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
+                layer_id=self.layer_id,
             )
         else:
-            if …
+            if enable_moe_dense_fully_dp():
                 mlp_tp_rank, mlp_tp_size = 0, 1
             else:
                 mlp_tp_rank, mlp_tp_size = None, None
```
```diff
@@ -1183,34 +1226,23 @@ class DeepseekV2DecoderLayer(nn.Module):
                 tp_size=mlp_tp_size,
             )
 
-        self.input_is_scattered = (
-            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
-        )
-        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
-
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )
 
-        …
-    @staticmethod
-    def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
-        is_sparse = is_nextn or (
-            config.n_routed_experts is not None
-            and layer_id >= config.first_k_dense_replace
-            and layer_id % config.moe_layer_freq == 0
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
         )
-        …
+
+    def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
+        return is_nextn or (
+            self.config.n_routed_experts is not None
+            and layer_id >= self.config.first_k_dense_replace
+            and layer_id % self.config.moe_layer_freq == 0
         )
-        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
 
     def forward(
         self,
```
```diff
@@ -1220,163 +1252,75 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
-        …
-                positions, hidden_states, forward_batch, residual, zero_allocator
-            )
-        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
-            return self.forward_ffn_with_full_input(
-                positions, hidden_states, forward_batch, residual, zero_allocator
-            )
-        else:
-            raise NotImplementedError
-
-    def forward_ffn_with_full_input(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-        residual: Optional[torch.Tensor],
-        zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
-
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        assert not (
-            self.attn_tp_size != 1 and self.input_is_scattered
-        ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
-
-        # Self Attention
-        hidden_states = self.self_attn(
+        return execute_operations(
+            inputs=dict(
                 positions=positions,
                 hidden_states=hidden_states,
                 forward_batch=forward_batch,
+                residual=residual,
                 zero_allocator=zero_allocator,
-        )
-
-        …
-        if get_tensor_model_parallel_world_size() > 1:
-            # all gather and all reduce
-            if self.dp_size != 1:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                dp_scatter(residual, hidden_states, forward_batch)
-                hidden_states = self.post_attention_layernorm(hidden_states)
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
-        else:
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
-
-        # Fully Connected
-        hidden_states = self.mlp(hidden_states)
-
-        # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
-        # Scatter
-        if self.dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
-        return hidden_states, residual
+            ),
+            operations=compute_layer_operations(self),
+        )
 
-    def …
+    def op_comm_prepare_attn(
         self,
+        state,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
-    ) …
-        …
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        if self.attn_tp_size != 1 and self.input_is_scattered:
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+    ):
+        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
+            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
+        )
+        state.update(
+            dict(
+                forward_batch=forward_batch,
+                positions=positions,
+                zero_allocator=zero_allocator,
             )
+        )
 
-        …
-            positions=positions,
-            hidden_states=…
-            forward_batch=forward_batch,
-            zero_allocator=zero_allocator,
+    def op_attn(self, state):
+        state.hidden_states_after_attn = self.self_attn(
+            positions=state.positions,
+            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+            forward_batch=state.forward_batch,
+            zero_allocator=state.zero_allocator,
         )
 
-        …
-            )
-        else:
-            if self.attn_tp_rank == 0:
-                hidden_states += residual
-            tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-            hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
-            residual = hidden_states
-            if hidden_states.shape[0] != 0:
-                hidden_states = self.post_attention_layernorm(hidden_states)
-        else:
-            if hidden_states.shape[0] != 0:
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
+    def op_comm_prepare_mlp(self, state):
+        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
+            self.layer_communicator.prepare_mlp(
+                state.pop("hidden_states_after_attn"),
+                state.pop("residual_after_input_ln"),
+                state.forward_batch,
+            )
+        )
 
+    def op_mlp(self, state):
+        hidden_states = state.pop("hidden_states_mlp_input")
         if not (
-            …
-            and (not self.…
+            enable_moe_dense_fully_dp()
+            and (not self.is_layer_sparse)
             and hidden_states.shape[0] == 0
         ):
-            …
-        if self.is_last_layer and self.attn_tp_size != 1:
-            hidden_states += residual
-            residual = None
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            state.hidden_states_mlp_output = self.mlp(
+                hidden_states, state.forward_batch.forward_mode
             )
+        else:
+            state.hidden_states_mlp_output = hidden_states
+
+    def op_comm_postprocess_layer(self, state):
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            state.pop("hidden_states_mlp_output"),
+            state.pop("residual_after_comm_pre_mlp"),
+            state.forward_batch,
+        )
 
+        state.clear(expect_keys={"positions", "forward_batch", "zero_allocator"})
         return hidden_states, residual
 
 
```
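The `op_comm_*` methods above delegate all cross-device movement to the `LayerCommunicator` constructed in `__init__` (new module `sglang/srt/layers/communicator.py`, +451 lines in the file list). Stripped of the operation plumbing, the per-layer flow it replaces — previously the inline `dp_gather_partial` / `dp_scatter` / all-reduce code deleted above — looks roughly like the following sketch; everything beyond the three calls shown in the diff is an assumption:

```python
# Illustrative sketch of the call sequence only; the communicator's internals
# (gather/scatter/all-reduce decisions per LayerScatterModes) are not shown here.
def decoder_layer_forward(layer, positions, hidden_states, forward_batch, residual, zero_allocator):
    # Input layernorm plus whatever gather/scatter the layer's scatter mode requires.
    hidden_states, residual = layer.layer_communicator.prepare_attn(
        hidden_states, residual, forward_batch
    )
    hidden_states = layer.self_attn(
        positions=positions,
        hidden_states=hidden_states,
        forward_batch=forward_batch,
        zero_allocator=zero_allocator,
    )
    # Post-attention layernorm plus the all-reduce / re-gather formerly inlined here.
    hidden_states, residual = layer.layer_communicator.prepare_mlp(
        hidden_states, residual, forward_batch
    )
    hidden_states = layer.mlp(hidden_states, forward_batch.forward_mode)
    # Scatter back (if needed) so the next layer sees its expected token layout.
    return layer.layer_communicator.postprocess_layer(hidden_states, residual, forward_batch)
```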
```diff
@@ -1398,7 +1342,7 @@ class DeepseekV2Model(nn.Module):
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
-        self.alt_stream = torch.cuda.Stream()
+        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
             [
                 DeepseekV2DecoderLayer(
```
```diff
@@ -1413,7 +1357,7 @@ class DeepseekV2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        self.dp_size = …
+        self.dp_size = get_local_attention_dp_size()
 
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
```
```diff
@@ -1441,11 +1385,11 @@ class DeepseekV2Model(nn.Module):
 
         residual = None
         for i in range(len(self.layers)):
-            …
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual, zero_allocator
+                )
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
                 hidden_states = self.norm(hidden_states)
```
```diff
@@ -1475,9 +1419,10 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = …
+        self.dp_size = get_local_attention_dp_size()
 
     def determine_n_share_experts_fusion(
         self, architecture: str = "DeepseekV3ForCausalLM"
```
```diff
@@ -1486,22 +1431,24 @@ class DeepseekV2ForCausalLM(nn.Module):
         if self.n_share_experts_fusion > 0:
             # Only Deepseek V3/R1 can use shared experts fusion optimization now.
             if (
-                …
+                not _is_cuda
+                or self.config.architectures[0] != architecture
                 or self.config.n_routed_experts != 256
             ):
                 self.n_share_experts_fusion = 0
                 global_server_args_dict["n_share_experts_fusion"] = 0
                 log_info_on_rank0(
                     logger,
-                    "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
+                    "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
                 )
             else:
                 assert (
                     self.n_share_experts_fusion == self.tp_size
-                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized …
+                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
         elif self.n_share_experts_fusion == 0:
             if (
-                …
+                _is_cuda
+                and torch.cuda.get_device_capability("cuda") >= (9, 0)
                 and self.config.architectures[0] == architecture
                 and self.config.n_routed_experts == 256
                 and (not global_server_args_dict["enable_deepep_moe"])
```
```diff
@@ -1659,11 +1606,19 @@ class DeepseekV2ForCausalLM(nn.Module):
             self_attn.w_vc = w_vc.contiguous()
             self_attn.use_deep_gemm_bmm = True
 
+        # TODO support nextn later
+        if not is_nextn:
+            self.routed_experts_weights_of_layer = {
+                layer_id: layer.mlp.get_moe_weights()
+                for layer_id, layer in enumerate(self.model.layers)
+                if isinstance(layer.mlp, DeepseekV2MoE)
+            }
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
                 num_nextn_layers = self.config.num_nextn_predict_layers
-                assert num_nextn_layers == 1, "Only 1 nextn layer is …
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
                 # compatible with old design
                 nextn_layer_id = (
                     0
```
```diff
@@ -1735,12 +1690,7 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        …
-            DeepEPMoE
-            if global_server_args_dict["enable_deepep_moe"]
-            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
-        )
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
```
```diff
@@ -1856,7 +1806,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                         q_a_proj_name in cached_a_proj
                         and kv_a_proj_name in cached_a_proj
                     ):
-
                         q_a_proj_weight = cached_a_proj[q_a_proj_name]
                         kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                         fused_weight = torch.cat(
```
```diff
@@ -1894,6 +1843,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                     torch.cuda.empty_cache()
                     torch.cuda.synchronize()
 
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.n_routed_experts,
+            num_groups=config.n_group,
+        )
+
 
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
```