sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -18,8 +18,7 @@
 
 import logging
 import os
-from
-from enum import Enum, IntEnum, auto
+from enum import IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -29,17 +28,17 @@ from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.communicator import (
+    LayerCommunicator,
+    LayerScatterModes,
+    enable_moe_dense_fully_dp,
+)
 from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
@@ -52,9 +51,8 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
@@ -72,15 +70,21 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.operations import execute_operations
+from sglang.srt.operations_strategy import compute_layer_operations
 from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
@@ -109,8 +113,6 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )
 
-expert_distribution_recorder = ExpertDistributionRecorder()
-
 logger = logging.getLogger(__name__)
 
 
@@ -125,6 +127,9 @@ class AttnForwardMethod(IntEnum):
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
 
+    # Use MLA but with fused RoPE
+    MLA_FUSED_ROPE = auto()
+
 
 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -139,6 +144,8 @@ class DeepseekV2MLP(nn.Module):
         tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
+        self.tp_size = tp_size
+
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
@@ -165,7 +172,10 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x,
+    def forward(self, x, forward_batch=None):
+        if (self.tp_size == 1) and x.shape[0] == 0:
+            return x
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -194,11 +204,20 @@ class MoEGate(nn.Module):
         return logits
 
 
+def is_non_idle_and_non_empty(forward_mode, hidden_states):
+    return (
+        (forward_mode is not None)
+        and not forward_mode.is_idle()
+        and hidden_states.shape[0] > 0
+    )
+
+
 class DeepseekV2MoE(nn.Module):
 
     def __init__(
         self,
         config: PretrainedConfig,
+        layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -207,6 +226,7 @@ class DeepseekV2MoE(nn.Module):
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+        self.layer_id = layer_id
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -222,17 +242,14 @@ class DeepseekV2MoE(nn.Module):
 
         self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
 
-
-
-
-
-        )
-
-        self.experts = MoEImpl(
-            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
+        self.experts = get_moe_impl_class()(
+            num_experts=config.n_routed_experts
+            + self.n_share_experts_fusion
+            + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
+            layer_id=self.layer_id,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
@@ -251,32 +268,29 @@ class DeepseekV2MoE(nn.Module):
         if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                reduce_results=False,
-                prefix=add_prefix("shared_experts", prefix),
-                tp_rank=0,
-                tp_size=1,
-            )
+            self.shared_experts = DeepseekV2MLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if global_server_args_dict["enable_deepep_moe"]
+                    else {}
+                ),
+            )
+
+        self.top_k = config.num_experts_per_tok
 
         if global_server_args_dict["enable_deepep_moe"]:
            # TODO: we will support tp < ep in the future
             self.ep_size = get_tensor_model_parallel_world_size()
-            self.num_experts =
-
+            self.num_experts = (
+                config.n_routed_experts
+                + global_server_args_dict["ep_num_redundant_experts"]
+            )
             self.renormalize = config.norm_topk_prob
             self.topk_group = config.topk_group
             self.num_expert_group = config.n_group
@@ -290,7 +304,7 @@ class DeepseekV2MoE(nn.Module):
                 group=parallel_state.get_tp_group().device_group,
                 router_topk=self.top_k,
                 permute_fusion=True,
-                num_experts=
+                num_experts=self.num_experts,
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
@@ -299,105 +313,137 @@ class DeepseekV2MoE(nn.Module):
                 return_recv_hook=True,
             )
 
-
-
-
-        if not global_server_args_dict["enable_deepep_moe"]:
-            return self.forward_normal(hidden_states)
-        else:
-            return self.forward_deepep(hidden_states, forward_mode)
+    @property
+    def _enable_deepep_moe(self):
+        return global_server_args_dict["enable_deepep_moe"]
 
-    def
-
-
-
-
-
-            * self.routed_scaling_factor
-        )
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-        return final_hidden_states
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
 
-    def
-
-
-        shared_output = None
-        if (
-            forward_mode is not None
-            and not forward_mode.is_idle()
-            and hidden_states.shape[0] > 0
+    def op_gate(self, state):
+        if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, state.hidden_states_mlp_input
         ):
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(
-            shared_output = self._forward_shared_experts(hidden_states)
-            topk_weights, topk_idx = select_experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=True,
-                renormalize=self.renormalize,
-                topk_group=self.topk_group,
-                num_expert_group=self.num_expert_group,
-                correction_bias=self.correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-            )
+            state.router_logits = self.gate(state.hidden_states_mlp_input)
         else:
-
-
-
-
-
+            state.router_logits = None
+
+    def op_shared_experts(self, state):
+        if (self.n_share_experts_fusion == 0) and (
+            (not self._enable_deepep_moe)
+            or is_non_idle_and_non_empty(
+                state.forward_batch.forward_mode, state.hidden_states_mlp_input
             )
-
+        ):
+            state.shared_output = self.shared_experts(state.hidden_states_mlp_input)
+        else:
+            state.shared_output = None
+
+    def op_select_experts(self, state):
+        router_logits = state.router_logits
+        hidden_states = state.hidden_states_mlp_input
+
+        if self._enable_deepep_moe:
+            if router_logits is not None:
+                state.topk_weights_local, state.topk_idx_local = select_experts(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    top_k=self.top_k,
+                    use_grouped_topk=True,
+                    renormalize=self.renormalize,
+                    topk_group=self.topk_group,
+                    num_expert_group=self.num_expert_group,
+                    correction_bias=self.correction_bias,
+                    routed_scaling_factor=self.routed_scaling_factor,
+                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                        layer_id=self.layer_id,
+                    ),
+                )
+            else:
+                state.topk_idx_local = torch.full(
+                    (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+                )
+                state.topk_weights_local = torch.empty(
+                    (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+                )
+
+    def op_dispatch_a(self, state):
+        if self._enable_deepep_moe and (self.ep_size > 1):
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            self.deepep_dispatcher.dispatch_a(
+                hidden_states=state.pop("hidden_states_mlp_input"),
+                topk_idx=state.pop("topk_idx_local"),
+                topk_weights=state.pop("topk_weights_local"),
+                forward_mode=state.forward_batch.forward_mode,
+            )
+
+    def op_dispatch_b(self, state):
+        if self._enable_deepep_moe and (self.ep_size > 1):
             (
-
-
-
-                reorder_topk_ids,
-                num_recv_tokens_per_expert,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            ) = self.deepep_dispatcher.
-
-
-
-
+                state.hidden_states_experts_input,
+                state.topk_idx_dispatched,
+                state.topk_weights_dispatched,
+                state.reorder_topk_ids,
+                state.num_recv_tokens_per_expert,
+                state.seg_indptr,
+                state.masked_m,
+                state.expected_m,
+            ) = self.deepep_dispatcher.dispatch_b()
+
+    def op_experts(self, state):
+        if self._enable_deepep_moe:
+            state.pop("router_logits")
+            state.hidden_states_experts_output = self.experts(
+                hidden_states=state.pop("hidden_states_experts_input"),
+                topk_idx=state.topk_idx_dispatched,
+                topk_weights=state.topk_weights_dispatched,
+                reorder_topk_ids=state.pop("reorder_topk_ids"),
+                seg_indptr=state.pop("seg_indptr"),
+                masked_m=state.pop("masked_m"),
+                expected_m=state.pop("expected_m"),
+                num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
+                forward_mode=state.forward_batch.forward_mode,
             )
-
-
-
-
-                reorder_topk_ids=reorder_topk_ids,
-                seg_indptr=seg_indptr,
-                masked_m=masked_m,
-                expected_m=expected_m,
-                num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-                forward_mode=forward_mode,
-            )
-        if self.ep_size > 1:
-            final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_mode,
+        else:
+            state.hidden_states_experts_output = self.experts(
+                hidden_states=state.pop("hidden_states_mlp_input"),
+                router_logits=state.pop("router_logits"),
             )
+
+    def op_combine_a(self, state):
+        if self._enable_deepep_moe and (self.ep_size > 1):
+            self.deepep_dispatcher.combine_a(
+                state.pop("hidden_states_experts_output"),
+                topk_idx=state.pop("topk_idx_dispatched"),
+                topk_weights=state.pop("topk_weights_dispatched"),
+                forward_mode=state.forward_batch.forward_mode,
+            )
+
+    def op_combine_b(self, state):
+        if self._enable_deepep_moe and (self.ep_size > 1):
+            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b()
+
+    def op_output(self, state):
+        final_hidden_states = (
+            state.pop("hidden_states_after_combine")
+            if self._enable_deepep_moe
+            else state.pop("hidden_states_experts_output")
+        )
+
         final_hidden_states *= self.routed_scaling_factor
 
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states +
+        if (s := state.pop("shared_output")) is not None:
+            final_hidden_states = final_hidden_states + s
 
-
+        if (not self._enable_deepep_moe) and (self.tp_size > 1):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
-
-        if self.n_share_experts_fusion == 0:
-            return self.shared_experts(hidden_states)
-        else:
-            return None
+        state.hidden_states_mlp_output = final_hidden_states
 
 
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
@@ -578,6 +624,18 @@ class DeepseekV2AttentionMLA(nn.Module):
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
+        def _dispatch_mla_subtype():
+            if _is_hip:
+                if (
+                    self.rocm_fused_decode_mla
+                    and forward_batch.forward_mode.is_decode()
+                ):
+                    return AttnForwardMethod.MLA_FUSED_ROPE
+                else:
+                    return AttnForwardMethod.MLA
+            else:
+                return AttnForwardMethod.MLA
+
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
@@ -589,7 +647,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA
             else:
-                return
+                return _dispatch_mla_subtype()
         elif self.attention_backend == "fa3":
             # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
             if forward_batch.extend_prefix_lens_cpu is not None:
@@ -606,7 +664,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
-                return
+                return _dispatch_mla_subtype()
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             if (
@@ -617,7 +675,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA
             else:
-                return
+                return _dispatch_mla_subtype()
 
     def forward(
         self,
@@ -640,23 +698,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_chunked_kv(
                 positions, hidden_states, forward_batch
             )
+        elif attn_forward_method == AttnForwardMethod.MLA:
+            return self.forward_absorb(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
+            return self.forward_absorb_fused_mla_rope(
+                positions, hidden_states, forward_batch
+            )
         else:
-
-                if (
-                    self.rocm_fused_decode_mla
-                    and forward_batch.forward_mode.is_decode()
-                ):
-                    return self.forward_absorb_fused_mla_rope(
-                        positions, hidden_states, forward_batch
-                    )
-                else:
-                    return self.forward_absorb(
-                        positions, hidden_states, forward_batch, zero_allocator
-                    )
-            else:
-                return self.forward_absorb(
-                    positions, hidden_states, forward_batch, zero_allocator
-                )
+            raise NotImplementedError
 
     def forward_normal(
         self,
@@ -710,6 +761,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         if self.q_lora_rank is not None:
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
@@ -717,7 +770,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             k_nope = latent_cache[..., : self.kv_lora_rank]
 
             # overlap qk norm
-            if self.alt_stream is not None and
+            if self.alt_stream is not None and get_is_capture_mode():
                 current_stream = torch.cuda.current_stream()
                 self.alt_stream.wait_stream(current_stream)
                 q = self.q_a_layernorm(q)
@@ -1101,19 +1154,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
 
-class _FFNInputMode(Enum):
-    # The MLP sublayer requires 1/tp_size tokens as input
-    SCATTERED = auto()
-    # The MLP sublayer requires all tokens as input
-    FULL = auto()
-
-
-@dataclass
-class _DecoderLayerInfo:
-    is_sparse: bool
-    ffn_input_mode: _FFNInputMode
-
-
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
@@ -1127,14 +1167,12 @@ class DeepseekV2DecoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.config = config
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
-        self.local_dp_size = get_local_attention_dp_size()
-        self.attn_tp_size = get_attention_tp_size()
-        self.attn_tp_rank = get_attention_tp_rank()
         self.self_attn = DeepseekV2AttentionMLA(
             config=config,
             hidden_size=self.hidden_size,
@@ -1156,19 +1194,25 @@ class DeepseekV2DecoderLayer(nn.Module):
             alt_stream=alt_stream,
         )
 
-        self.
-
-
+        self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
+        is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
         )
 
-        if self.
+        if self.is_layer_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
+                layer_id=self.layer_id,
             )
         else:
-            if
+            if enable_moe_dense_fully_dp():
                 mlp_tp_rank, mlp_tp_size = 0, 1
             else:
                 mlp_tp_rank, mlp_tp_size = None, None
@@ -1182,35 +1226,23 @@ class DeepseekV2DecoderLayer(nn.Module):
                 tp_size=mlp_tp_size,
             )
 
-        self.input_is_scattered = (
-            layer_id > 0
-            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
-        )
-        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
-
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )
 
-
-
-
-
-    @staticmethod
-    def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
-        is_sparse = is_nextn or (
-            config.n_routed_experts is not None
-            and layer_id >= config.first_k_dense_replace
-            and layer_id % config.moe_layer_freq == 0
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
        )
-
-
-
-
-
+
+    def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
+        return is_nextn or (
+            self.config.n_routed_experts is not None
+            and layer_id >= self.config.first_k_dense_replace
+            and layer_id % self.config.moe_layer_freq == 0
         )
-        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
 
     def forward(
         self,
@@ -1220,163 +1252,75 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
-
-
-                positions, hidden_states, forward_batch, residual, zero_allocator
-            )
-        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
-            return self.forward_ffn_with_full_input(
-                positions, hidden_states, forward_batch, residual, zero_allocator
-            )
-        else:
-            raise NotImplementedError
-
-    def forward_ffn_with_full_input(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-        residual: Optional[torch.Tensor],
-        zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
-
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        assert not (
-            self.attn_tp_size != 1 and self.input_is_scattered
-        ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
-
-        # Self Attention
-        hidden_states = self.self_attn(
+        return execute_operations(
+            inputs=dict(
                 positions=positions,
                 hidden_states=hidden_states,
                 forward_batch=forward_batch,
+                residual=residual,
                 zero_allocator=zero_allocator,
-        )
-
-
-        if get_tensor_model_parallel_world_size() > 1:
-            # all gather and all reduce
-            if self.local_dp_size != 1:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                dp_scatter(residual, hidden_states, forward_batch)
-                hidden_states = self.post_attention_layernorm(hidden_states)
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
-        else:
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
-
-        # Fully Connected
-        hidden_states = self.mlp(hidden_states)
-
-        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
-        # Scatter
-        if self.local_dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
-        return hidden_states, residual
+            ),
+            operations=compute_layer_operations(self),
+        )
 
-    def
+    def op_comm_prepare_attn(
         self,
+        state,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
-    )
-
-
-
-
-
-
-
-
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        if self.attn_tp_size != 1 and self.input_is_scattered:
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            attn_tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+    ):
+        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
+            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
+        )
+        state.update(
+            dict(
+                forward_batch=forward_batch,
+                positions=positions,
+                zero_allocator=zero_allocator,
            )
+        )
 
-
-
-            positions=positions,
-            hidden_states=
-            forward_batch=forward_batch,
-            zero_allocator=zero_allocator,
+    def op_attn(self, state):
+        state.hidden_states_after_attn = self.self_attn(
+            positions=state.positions,
+            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+            forward_batch=state.forward_batch,
+            zero_allocator=state.zero_allocator,
         )
 
-
-
-
-
-
-
-
-
-            )
-        else:
-            if self.attn_tp_rank == 0:
-                hidden_states += residual
-            tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-            hidden_states = tensor_list[self.attn_tp_rank]
-            attn_tp_reduce_scatter(hidden_states, tensor_list)
-            residual = hidden_states
-            if hidden_states.shape[0] != 0:
-                hidden_states = self.post_attention_layernorm(hidden_states)
-        else:
-            if hidden_states.shape[0] != 0:
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
+    def op_comm_prepare_mlp(self, state):
+        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
+            self.layer_communicator.prepare_mlp(
+                state.pop("hidden_states_after_attn"),
+                state.pop("residual_after_input_ln"),
+                state.forward_batch,
+            )
+        )
 
+    def op_mlp(self, state):
+        hidden_states = state.pop("hidden_states_mlp_input")
         if not (
-
-            and (not self.
+            enable_moe_dense_fully_dp()
+            and (not self.is_layer_sparse)
             and hidden_states.shape[0] == 0
         ):
-
-
-        if self.is_last_layer and self.attn_tp_size != 1:
-            hidden_states += residual
-            residual = None
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            attn_tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            state.hidden_states_mlp_output = self.mlp(
+                hidden_states, state.forward_batch.forward_mode
             )
+        else:
+            state.hidden_states_mlp_output = hidden_states
+
+    def op_comm_postprocess_layer(self, state):
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            state.pop("hidden_states_mlp_output"),
+            state.pop("residual_after_comm_pre_mlp"),
+            state.forward_batch,
+        )
 
+        state.clear(expect_keys={"positions", "forward_batch", "zero_allocator"})
         return hidden_states, residual
 
 
@@ -1398,7 +1342,7 @@ class DeepseekV2Model(nn.Module):
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
-        self.alt_stream = torch.cuda.Stream()
+        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
             [
                 DeepseekV2DecoderLayer(
@@ -1441,11 +1385,11 @@
 
         residual = None
         for i in range(len(self.layers)):
-
-
-
-
-
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual, zero_allocator
+                )
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
                 hidden_states = self.norm(hidden_states)
@@ -1662,6 +1606,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                 self_attn.w_vc = w_vc.contiguous()
                 self_attn.use_deep_gemm_bmm = True
 
+        # TODO support nextn later
+        if not is_nextn:
+            self.routed_experts_weights_of_layer = {
+                layer_id: layer.mlp.get_moe_weights()
+                for layer_id, layer in enumerate(self.model.layers)
+                if isinstance(layer.mlp, DeepseekV2MoE)
+            }
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
@@ -1738,12 +1690,7 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-
-            DeepEPMoE
-            if global_server_args_dict["enable_deepep_moe"]
-            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
-        )
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
@@ -1859,7 +1806,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                        q_a_proj_name in cached_a_proj
                        and kv_a_proj_name in cached_a_proj
                    ):
-
                        q_a_proj_weight = cached_a_proj[q_a_proj_name]
                        kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                        fused_weight = torch.cat(
@@ -1897,6 +1843,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
 
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.n_routed_experts,
+            num_groups=config.n_group,
+        )
+
 
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass