sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +3 -6
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +6 -2
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +105 -6
- sglang/srt/disaggregation/mini_lb.py +74 -9
- sglang/srt/disaggregation/mooncake/conn.py +33 -63
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +137 -17
- sglang/srt/disaggregation/utils.py +32 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +883 -209
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +20 -5
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +17 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -88
- sglang/srt/layers/quantization/fp8_utils.py +189 -264
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +9 -8
- sglang/srt/layers/sampler.py +7 -12
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +15 -4
- sglang/srt/managers/scheduler.py +28 -77
- sglang/srt/managers/tokenizer_manager.py +116 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +41 -29
- sglang/srt/mem_cache/memory_pool.py +38 -15
- sglang/srt/model_executor/cuda_graph_runner.py +15 -10
- sglang/srt/model_executor/model_runner.py +39 -31
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +292 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +31 -203
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +86 -72
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +6 -14
- sglang/srt/utils.py +62 -6
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -18,7 +18,8 @@
 
 import logging
 import os
-from
+from dataclasses import dataclass
+from enum import Enum, IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -28,6 +29,7 @@ from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
@@ -51,10 +53,15 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.
+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_tensor_quant_mla_deep_gemm_masked_fp8,
+    per_tensor_quant_mla_fp8,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
@@ -73,7 +80,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -81,9 +88,11 @@ _is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
 
-    from sglang.srt.layers.
+    from sglang.srt.layers.quantization.deep_gemm import (
+        grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
+    )
 else:
-    from vllm import
+    from vllm._custom_ops import awq_dequantize
 
 if _is_hip:
     from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
@@ -96,7 +105,6 @@ logger = logging.getLogger(__name__)
 
 
 class AttnForwardMethod(IntEnum):
-
     # Use multi-head attention
     MHA = auto()
 
@@ -147,7 +155,7 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(self, x, forward_mode: Optional[ForwardMode] = None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -188,11 +196,7 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.n_share_experts_fusion =
-            global_server_args_dict["n_share_experts_fusion"]
-            if global_server_args_dict["n_share_experts_fusion"] is not None
-            else 0
-        )
+        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -225,6 +229,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
             **(
                 dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
@@ -333,6 +338,7 @@ class DeepseekV2MoE(nn.Module):
                 topk_group=self.topk_group,
                 num_expert_group=self.num_expert_group,
                 correction_bias=self.correction_bias,
+                routed_scaling_factor=self.routed_scaling_factor,
             )
         if self.ep_size > 1:
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
@@ -373,7 +379,7 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states
 
     def _forward_shared_experts(self, hidden_states):
-        if self.
+        if self.n_share_experts_fusion == 0:
            return self.shared_experts(hidden_states)
        else:
            return None
@@ -387,179 +393,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
-class DeepseekV2Attention(nn.Module):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        hidden_size: int,
-        num_heads: int,
-        qk_nope_head_dim: int,
-        qk_rope_head_dim: int,
-        v_head_dim: int,
-        q_lora_rank: int,
-        kv_lora_rank: int,
-        rope_theta: float = 10000,
-        rope_scaling: Optional[Dict[str, Any]] = None,
-        max_position_embeddings: int = 8192,
-        quant_config: Optional[QuantizationConfig] = None,
-        layer_id=None,
-        reduce_results: bool = True,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.layer_id = layer_id
-        self.hidden_size = hidden_size
-        self.qk_nope_head_dim = qk_nope_head_dim
-        self.qk_rope_head_dim = qk_rope_head_dim
-        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
-        self.v_head_dim = v_head_dim
-        self.q_lora_rank = q_lora_rank
-        self.kv_lora_rank = kv_lora_rank
-
-        self.dp_size = get_attention_dp_size()
-        attn_tp_rank = get_attention_tp_rank()
-        attn_tp_size = get_attention_tp_size()
-
-        self.num_heads = num_heads
-        assert num_heads % attn_tp_size == 0
-        self.num_local_heads = num_heads // attn_tp_size
-        self.scaling = self.qk_head_dim**-0.5
-        self.rope_theta = rope_theta
-        self.max_position_embeddings = max_position_embeddings
-
-        if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
-                self.hidden_size,
-                self.q_lora_rank,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_a_proj", prefix),
-            )
-            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-            self.q_b_proj = ColumnParallelLinear(
-                q_lora_rank,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_b_proj", prefix),
-            )
-        else:
-            self.q_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_proj", prefix),
-                tp_rank=attn_tp_rank,
-                tp_size=attn_tp_size,
-            )
-
-        self.kv_a_proj_with_mqa = ReplicatedLinear(
-            self.hidden_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
-        )
-        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        self.kv_b_proj = ColumnParallelLinear(
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_b_proj", prefix),
-        )
-        # O projection.
-        self.o_proj = RowParallelLinear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("o_proj", prefix),
-            reduce_results=reduce_results,
-            tp_rank=attn_tp_rank,
-            tp_size=attn_tp_size,
-        )
-        rope_scaling["rope_type"] = "deepseek_yarn"
-        self.rotary_emb = get_rope_wrapper(
-            qk_rope_head_dim,
-            rotary_dim=qk_rope_head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-            is_neox_style=False,
-            device=global_server_args_dict["device"],
-        )
-
-        if rope_scaling:
-            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
-            scaling_factor = rope_scaling["factor"]
-            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
-            self.scaling = self.scaling * mscale * mscale
-
-        # TODO, support head_size 192
-        self.attn = RadixAttention(
-            self.num_local_heads,
-            256,
-            self.scaling,
-            num_kv_heads=self.num_local_heads,
-            layer_id=layer_id,
-            quant_config=quant_config,
-            prefix=add_prefix("attn", prefix),
-        )
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
-        if hidden_states.shape[0] == 0:
-            assert (
-                not self.o_proj.reduce_results
-            ), "short-circuiting allreduce will lead to hangs"
-            return hidden_states
-
-        if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
-            q = self.q_a_layernorm(q)
-            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
-        else:
-            q = self.q_proj(hidden_states)[0].view(
-                -1, self.num_local_heads, self.qk_head_dim
-            )
-        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
-        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a.contiguous())
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        k_pe = latent_cache[:, :, self.kv_lora_rank :]
-        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)
-        k[..., : self.qk_nope_head_dim] = k_nope
-        k[..., self.qk_nope_head_dim :] = k_pe
-        q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view(
-            -1, self.num_local_heads * 256
-        )
-        k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value=0).view(
-            -1, self.num_local_heads * 256
-        )
-        v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(
-            -1, self.num_local_heads * 256
-        )
-        attn_output = self.attn(q, k, v, forward_batch)
-        attn_output = attn_output.view(-1, self.num_local_heads, 256)[
-            ..., : self.v_head_dim
-        ].reshape(-1, self.num_local_heads * self.v_head_dim)
-        output, _ = self.o_proj(attn_output)
-        return output
-
-
 class DeepseekV2AttentionMLA(nn.Module):
 
     def __init__(
@@ -705,6 +538,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_vc = None
         self.w_scale = None
 
+        self.w_scale_k = None
+        self.w_scale_v = None
+        self.use_deep_gemm_bmm = False
+
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
@@ -762,6 +599,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:
             assert (
@@ -787,9 +625,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                     positions, hidden_states, forward_batch
                 )
             else:
-                return self.forward_absorb(
+                return self.forward_absorb(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
         else:
-            return self.forward_absorb(
+            return self.forward_absorb(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
 
     def forward_normal(
         self,
@@ -838,6 +680,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -853,7 +696,24 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
-        if self.
+        if self.use_deep_gemm_bmm:
+            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+                per_tensor_quant_mla_deep_gemm_masked_fp8(
+                    q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+                )
+            )
+            q_nope_out = q_nope.new_empty(
+                (self.num_local_heads, aligned_m, self.kv_lora_rank)
+            )
+            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+                (q_nope_val, q_nope_scale),
+                (self.w_kc, self.w_scale_k),
+                q_nope_out,
+                masked_m,
+                expected_m,
+            )
+            q_nope_out = q_nope_out[:, :expected_m, :]
+        elif self.w_kc.dtype == torch.float8_e4m3fnuz:
             # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
             q_nope_out = torch.bmm(
                 q_nope.to(torch.bfloat16).transpose(0, 1),
@@ -861,7 +721,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                q_nope.transpose(0, 1),
+                q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -884,7 +745,24 @@ class DeepseekV2AttentionMLA(nn.Module):
         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
-        if self.
+        if self.use_deep_gemm_bmm:
+            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
+                per_tensor_quant_mla_deep_gemm_masked_fp8(
+                    attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                )
+            )
+            attn_bmm_output = attn_output.new_empty(
+                (self.num_local_heads, aligned_m, self.v_head_dim)
+            )
+            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+                (attn_output_val, attn_output_scale),
+                (self.w_vc, self.w_scale_v),
+                attn_bmm_output,
+                masked_m,
+                expected_m,
+            )
+            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
+        elif self.w_vc.dtype == torch.float8_e4m3fnuz:
             # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
@@ -892,7 +770,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                attn_output.transpose(0, 1),
+                attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -913,6 +792,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         enable_rope_fusion = (
             os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
@@ -939,7 +819,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                q_nope.transpose(0, 1),
+                q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -1035,7 +917,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                attn_output.transpose(0, 1),
+                attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -1173,6 +1057,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
 
+class _FFNInputMode(Enum):
+    # The MLP sublayer requires 1/tp_size tokens as input
+    SCATTERED = auto()
+    # The MLP sublayer requires all tokens as input
+    FULL = auto()
+
+
+@dataclass
+class _DecoderLayerInfo:
+    is_sparse: bool
+    ffn_input_mode: _FFNInputMode
+
+
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
@@ -1183,14 +1080,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         is_nextn: bool = False,
         prefix: str = "",
     ) -> None:
-
-        def is_sparse_layer(l: int):
-            return (
-                config.n_routed_experts is not None
-                and l >= config.first_k_dense_replace
-                and l % config.moe_layer_freq == 0
-            )
-
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -1201,68 +1090,54 @@ class DeepseekV2DecoderLayer(nn.Module):
         self.dp_size = get_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
+        self.self_attn = DeepseekV2AttentionMLA(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            qk_nope_head_dim=config.qk_nope_head_dim,
+            qk_rope_head_dim=config.qk_rope_head_dim,
+            v_head_dim=config.v_head_dim,
+            q_lora_rank=(
+                config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+            ),
+            kv_lora_rank=config.kv_lora_rank,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            layer_id=layer_id,
+            reduce_results=False,
+            prefix=add_prefix("self_attn", prefix),
+        )
 
-
-
-
-
-                num_heads=config.num_attention_heads,
-                qk_nope_head_dim=config.qk_nope_head_dim,
-                qk_rope_head_dim=config.qk_rope_head_dim,
-                v_head_dim=config.v_head_dim,
-                q_lora_rank=(
-                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
-                ),
-                kv_lora_rank=config.kv_lora_rank,
-                rope_theta=rope_theta,
-                rope_scaling=rope_scaling,
-                max_position_embeddings=max_position_embeddings,
-                quant_config=quant_config,
-                layer_id=layer_id,
-                reduce_results=False,
-                prefix=add_prefix("self_attn", prefix),
-            )
-        else:
-            self.self_attn = DeepseekV2Attention(
-                config=config,
-                hidden_size=self.hidden_size,
-                num_heads=config.num_attention_heads,
-                qk_nope_head_dim=config.qk_nope_head_dim,
-                qk_rope_head_dim=config.qk_rope_head_dim,
-                v_head_dim=config.v_head_dim,
-                q_lora_rank=(
-                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
-                ),
-                kv_lora_rank=config.kv_lora_rank,
-                rope_theta=rope_theta,
-                rope_scaling=rope_scaling,
-                max_position_embeddings=max_position_embeddings,
-                quant_config=quant_config,
-                layer_id=layer_id,
-                reduce_results=False,
-                prefix=add_prefix("self_attn", prefix),
-            )
+        self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
+        previous_layer_info = self._compute_info(
+            config, layer_id=layer_id - 1, is_nextn=False
+        )
 
-        if
+        if self.info.is_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
-            self.is_sparse = True
         else:
+            if self._enable_moe_dense_fully_dp():
+                mlp_tp_rank, mlp_tp_size = 0, 1
+            else:
+                mlp_tp_rank, mlp_tp_size = None, None
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
+                tp_rank=mlp_tp_rank,
+                tp_size=mlp_tp_size,
            )
-            self.is_sparse = False
 
         self.input_is_scattered = (
-
-            and global_server_args_dict["enable_deepep_moe"]
+            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
 
@@ -1271,28 +1146,51 @@ class DeepseekV2DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+    @staticmethod
+    def _enable_moe_dense_fully_dp():
+        return global_server_args_dict["moe_dense_tp_size"] == 1
+
+    @staticmethod
+    def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
+        is_sparse = is_nextn or (
+            config.n_routed_experts is not None
+            and layer_id >= config.first_k_dense_replace
+            and layer_id % config.moe_layer_freq == 0
+        )
+        ffn_input_mode = (
+            _FFNInputMode.SCATTERED
+            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+            or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
+            else _FFNInputMode.FULL
+        )
+        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
-        if
-            return self.
-                positions, hidden_states, forward_batch, residual
+        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
+            return self.forward_ffn_with_scattered_input(
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
-
-            return self.
-                positions, hidden_states, forward_batch, residual
+        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
+            return self.forward_ffn_with_full_input(
+                positions, hidden_states, forward_batch, residual, zero_allocator
            )
+        else:
+            raise NotImplementedError
 
-    def
+    def forward_ffn_with_full_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
 
         if hidden_states.shape[0] == 0:
@@ -1313,6 +1211,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
             forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
         )
 
         # Gather
@@ -1354,12 +1253,13 @@ class DeepseekV2DecoderLayer(nn.Module):
 
         return hidden_states, residual
 
-    def
+    def forward_ffn_with_scattered_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
 
         if hidden_states.shape[0] == 0:
@@ -1385,6 +1285,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
            forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
         )
 
         if self.attn_tp_size != 1:
@@ -1410,7 +1311,13 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual
         )
-
+
+        if not (
+            self._enable_moe_dense_fully_dp()
+            and (not self.info.is_sparse)
+            and hidden_states.shape[0] == 0
+        ):
+            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
 
         if self.is_last_layer and self.attn_tp_size != 1:
             hidden_states += residual
@@ -1466,6 +1373,14 @@ class DeepseekV2Model(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            # TODO for two-batch-overlap, we need a larger buffer size
+            buffer_size=len(self.layers) * 2,
+            dtype=torch.float32,
+            device=(
+                input_embeds.device if input_embeds is not None else input_ids.device
+            ),
+        )
 
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
@@ -1477,7 +1392,7 @@ class DeepseekV2Model(nn.Module):
             expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
@@ -1500,24 +1415,33 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.n_share_experts_fusion > 0:
+            # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+            if (
+                self.config.architectures[0] != "DeepseekV3ForCausalLM"
+                or self.config.n_routed_experts != 256
+            ):
+                self.n_share_experts_fusion = 0
+                global_server_args_dict["n_share_experts_fusion"] = 0
+                logger.info(
+                    "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+                )
+            else:
+                assert (
+                    self.n_share_experts_fusion == self.tp_size
+                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performace."
+        elif self.n_share_experts_fusion == 0:
+            if (
+                torch.cuda.get_device_capability("cuda") >= (9, 0)
+                and self.config.architectures[0] == "DeepseekV3ForCausalLM"
+                and self.config.n_routed_experts == 256
+                and (not global_server_args_dict["enable_deepep_moe"])
+            ):
+                self.n_share_experts_fusion = self.tp_size
+                global_server_args_dict["n_share_experts_fusion"] = self.tp_size
+                logger.info(
+                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
+                )
 
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
@@ -1552,78 +1476,92 @@ class DeepseekV2ForCausalLM(nn.Module):
     def post_load_weights(self):
 
         # Perform post-processing after loading weights
-
-
-
-
-            if
-
-
-
-
-
-                        self_attn.kv_b_proj.qzeros,
-                    ).T
-                else:
-                    w = ops.awq_dequantize(
-                        self_attn.kv_b_proj.qweight,
-                        self_attn.kv_b_proj.scales,
-                        self_attn.kv_b_proj.qzeros,
-                        0,
-                        0,
-                        0,
-                    ).T
+        for layer_id in range(self.config.num_hidden_layers):
+            self_attn = self.model.layers[layer_id].self_attn
+            if hasattr(self_attn.kv_b_proj, "qweight"):
+                # AWQ compatible
+                if _is_cuda:
+                    w = awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                    ).T
                 else:
-                    w =
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    w = awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
+            else:
+                w = self_attn.kv_b_proj.weight
+            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+            # This may affect the accuracy of fp8 model.
+            # Fix deepseek v3 blockwise bmm by using deep_gemm
+            use_deep_gemm_bmm = False
+            model_dtype = torch.get_default_dtype()
+
+            if w.dtype in (
+                torch.float8_e4m3fn,
+                torch.float8_e4m3fnuz,
+            ):
+                if hasattr(self.quant_config, "weight_block_size"):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        if _is_hip:
+                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                weight=w,
+                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                input_scale=None,
+                            )
+                        else:
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
 
+                        if (
+                            _is_cuda
+                            and _ENABLE_JIT_DEEPGEMM
+                            and weight_block_size[0] == 128
+                            and weight_block_size[1] == 128
+                            and model_dtype == torch.bfloat16
+                        ):
+                            block_scale = weight_scale
+                            use_deep_gemm_bmm = True
+                        else:
                             w, scale = block_quant_to_tensor_quant(
                                 weight, weight_scale, weight_block_size
                             )
                             self_attn.w_scale = scale
-
+                else:
+                    weight = w
+                    weight_scale = self_attn.kv_b_proj.weight_scale
+                    w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
+                    self_attn.w_scale = scale
+
+            if w.dtype == torch.int8:
+                if hasattr(self.quant_config, "weight_block_size"):
+                    # block-wise int8 need it
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        weight = w
-                        weight_scale = self_attn.kv_b_proj.
-                        w
-
-
-
-
-
-
-
-
-
-
-
-                        ).to(torch.bfloat16)
-                else:
-                    # channel-wise int8 need it
-                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                        torch.bfloat16
-                    )
-            w_kc, w_vc = w.unflatten(
-                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                        w = int8_block_dequant(
+                            weight, weight_scale, weight_block_size
+                        ).to(torch.bfloat16)
+                else:
+                    # channel-wise int8 need it
+                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                        torch.bfloat16
+                    )
+
+            w_kc, w_vc = w.unflatten(
+                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            if not use_deep_gemm_bmm:
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                 self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                 if (
@@ -1633,6 +1571,17 @@ class DeepseekV2ForCausalLM(nn.Module):
                     self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                     if _is_hip:
                         self_attn.w_scale *= 2.0
+            else:
+                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
+                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
+                ws_kc, ws_vc = block_scale.unflatten(
+                    0, (-1, (num_tiles_k + num_tiles_n))
+                ).split([num_tiles_k, num_tiles_n], dim=1)
+                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
+                self_attn.w_scale_v = ws_vc.contiguous()
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
+                self_attn.w_vc = w_vc.contiguous()
+                self_attn.use_deep_gemm_bmm = True
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -1640,7 +1589,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.n_share_experts_fusion
+        if self.n_share_experts_fusion > 0:
            weights_list = list(weights)
            weights_dict = dict(weights_list)
            if self.quant_config.get_name() == "w8a8_int8":
@@ -1682,7 +1631,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                                 f"mlp.experts."
                                 f"{self.config.n_routed_experts + num_repeat}"
                                 f".{suffix}",
-                                weights_dict[shared_expert_weight_name]
+                                weights_dict[shared_expert_weight_name],
                             )
                        )
                    names_to_remove += [shared_expert_weight_name]
@@ -1699,12 +1648,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts
-            + (
-                self.n_share_experts_fusion
-                if self.n_share_experts_fusion is not None
-                else 0
-            ),
+            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
         )
 
         params_dict = dict(self.named_parameters())
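
A recurring change in the diff above is that `DeepseekV2Model.forward` now creates a `zero_allocator` (a `BumpAllocator` imported from `sglang.srt.utils`) and threads it through every decoder layer into the MLA FP8 paths, so `per_tensor_quant_mla_fp8` receives a pre-allocated zero-valued scale slot instead of allocating a fresh tensor on every call. The snippet below is a minimal illustrative sketch of that pattern, not sglang's actual class; only the constructor arguments and the `allocate(1)` call shape are taken from the diff.

```python
import torch


class BumpAllocator:
    """Sketch of the bump-allocator pattern used by the diff (assumed interface).

    A zero-filled buffer is allocated once per forward pass; each caller bumps
    an offset to take a small slice instead of calling torch.zeros() inside the
    per-layer hot path.
    """

    def __init__(self, buffer_size: int, dtype: torch.dtype, device):
        self._buffer = torch.zeros((buffer_size,), dtype=dtype, device=device)
        self._offset = 0

    def allocate(self, size: int) -> torch.Tensor:
        # Hand out a view into the pre-zeroed buffer and advance the offset.
        assert self._offset + size <= self._buffer.shape[0], "bump allocator exhausted"
        out = self._buffer[self._offset : self._offset + size]
        self._offset += size
        return out


# Toy usage mirroring DeepseekV2Model.forward: two slots per layer
# (one for the q_nope scale, one for the attn_output scale).
num_layers = 4
zero_allocator = BumpAllocator(
    buffer_size=num_layers * 2, dtype=torch.float32, device="cpu"
)
for _ in range(num_layers):
    q_scale_zero = zero_allocator.allocate(1)      # would feed per_tensor_quant_mla_fp8
    attn_scale_zero = zero_allocator.allocate(1)   # would feed per_tensor_quant_mla_fp8
```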
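The other structural change is the new `_compute_info` helper on `DeepseekV2DecoderLayer`, which decides per layer whether the FFN sublayer consumes the scattered (1/tp_size) token shard or the gathered full batch, and then dispatches to `forward_ffn_with_scattered_input` or `forward_ffn_with_full_input`. Below is a condensed, hypothetical restatement of that decision as a pure function (parameter names chosen here for clarity; the authoritative logic is `_compute_info` and `_enable_moe_dense_fully_dp` in the diff).

```python
from enum import Enum, auto


class FFNInputMode(Enum):
    SCATTERED = auto()  # MLP consumes the 1/tp_size token shard
    FULL = auto()       # MLP consumes the gathered full token set


def ffn_input_mode(is_sparse: bool, enable_deepep_moe: bool, moe_dense_tp_size: int) -> FFNInputMode:
    # Mirrors DeepseekV2DecoderLayer._compute_info from the diff above.
    if (enable_deepep_moe and is_sparse) or (moe_dense_tp_size == 1 and not is_sparse):
        return FFNInputMode.SCATTERED
    return FFNInputMode.FULL


# Decision table implied by the refactor:
assert ffn_input_mode(True,  True,  8) is FFNInputMode.SCATTERED   # DeepEP MoE layer
assert ffn_input_mode(False, True,  1) is FFNInputMode.SCATTERED   # dense layer, moe_dense_tp_size == 1
assert ffn_input_mode(True,  False, 8) is FFNInputMode.FULL        # MoE layer without DeepEP
assert ffn_input_mode(False, False, 8) is FFNInputMode.FULL        # dense layer, default TP
```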