sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +0 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +62 -6
- sglang/srt/disaggregation/mini_lb.py +5 -1
- sglang/srt/disaggregation/mooncake/conn.py +32 -62
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/prefill.py +40 -4
- sglang/srt/disaggregation/utils.py +15 -0
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +114 -71
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -57
- sglang/srt/layers/quantization/fp8_utils.py +187 -262
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +3 -2
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +1 -0
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +2 -4
- sglang/srt/managers/scheduler.py +12 -71
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +7 -2
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/model_runner.py +20 -27
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +289 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +29 -201
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +34 -32
- sglang/srt/speculative/eagle_worker.py +4 -7
- sglang/srt/utils.py +16 -1
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -18,7 +18,8 @@
 
 import logging
 import os
-from
+from dataclasses import dataclass
+from enum import Enum, IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -28,6 +29,7 @@ from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
@@ -51,10 +53,15 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8_kernel import
+from sglang.srt.layers.quantization.fp8_kernel import (
+    _enable_jit_deepgemm_bmm,
+    per_tensor_quant_mla_deep_gemm_masked_fp8,
+    per_tensor_quant_mla_fp8,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
@@ -73,17 +80,16 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 
 if _is_cuda:
+    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
-
-    from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 else:
-    from vllm import
+    from vllm._custom_ops import awq_dequantize
 
 if _is_hip:
     from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
@@ -96,7 +102,6 @@ logger = logging.getLogger(__name__)
 
 
 class AttnForwardMethod(IntEnum):
-
     # Use multi-head attention
     MHA = auto()
 
@@ -147,7 +152,7 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(self, x, forward_mode: Optional[ForwardMode] = None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -188,11 +193,7 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.n_share_experts_fusion = (
-            global_server_args_dict["n_share_experts_fusion"]
-            if global_server_args_dict["n_share_experts_fusion"] is not None
-            else 0
-        )
+        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -225,6 +226,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
             **(
                 dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
@@ -333,6 +335,7 @@ class DeepseekV2MoE(nn.Module):
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
         if self.ep_size > 1:
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
@@ -373,7 +376,7 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states
 
     def _forward_shared_experts(self, hidden_states):
-        if self.
+        if self.n_share_experts_fusion == 0:
             return self.shared_experts(hidden_states)
         else:
             return None
@@ -387,179 +390,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
-class DeepseekV2Attention(nn.Module):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        hidden_size: int,
-        num_heads: int,
-        qk_nope_head_dim: int,
-        qk_rope_head_dim: int,
-        v_head_dim: int,
-        q_lora_rank: int,
-        kv_lora_rank: int,
-        rope_theta: float = 10000,
-        rope_scaling: Optional[Dict[str, Any]] = None,
-        max_position_embeddings: int = 8192,
-        quant_config: Optional[QuantizationConfig] = None,
-        layer_id=None,
-        reduce_results: bool = True,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.layer_id = layer_id
-        self.hidden_size = hidden_size
-        self.qk_nope_head_dim = qk_nope_head_dim
-        self.qk_rope_head_dim = qk_rope_head_dim
-        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
-        self.v_head_dim = v_head_dim
-        self.q_lora_rank = q_lora_rank
-        self.kv_lora_rank = kv_lora_rank
-
-        self.dp_size = get_attention_dp_size()
-        attn_tp_rank = get_attention_tp_rank()
-        attn_tp_size = get_attention_tp_size()
-
-        self.num_heads = num_heads
-        assert num_heads % attn_tp_size == 0
-        self.num_local_heads = num_heads // attn_tp_size
-        self.scaling = self.qk_head_dim**-0.5
-        self.rope_theta = rope_theta
-        self.max_position_embeddings = max_position_embeddings
-
-        if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
-                self.hidden_size,
-                self.q_lora_rank,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_a_proj", prefix),
-            )
-            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-            self.q_b_proj = ColumnParallelLinear(
-                q_lora_rank,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_b_proj", prefix),
-            )
-        else:
-            self.q_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_proj", prefix),
-                tp_rank=attn_tp_rank,
-                tp_size=attn_tp_size,
-            )
-
-        self.kv_a_proj_with_mqa = ReplicatedLinear(
-            self.hidden_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
-        )
-        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        self.kv_b_proj = ColumnParallelLinear(
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_b_proj", prefix),
-        )
-        # O projection.
-        self.o_proj = RowParallelLinear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("o_proj", prefix),
-            reduce_results=reduce_results,
-            tp_rank=attn_tp_rank,
-            tp_size=attn_tp_size,
-        )
-        rope_scaling["rope_type"] = "deepseek_yarn"
-        self.rotary_emb = get_rope_wrapper(
-            qk_rope_head_dim,
-            rotary_dim=qk_rope_head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-            is_neox_style=False,
-            device=global_server_args_dict["device"],
-        )
-
-        if rope_scaling:
-            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
-            scaling_factor = rope_scaling["factor"]
-            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
-            self.scaling = self.scaling * mscale * mscale
-
-        # TODO, support head_size 192
-        self.attn = RadixAttention(
-            self.num_local_heads,
-            256,
-            self.scaling,
-            num_kv_heads=self.num_local_heads,
-            layer_id=layer_id,
-            quant_config=quant_config,
-            prefix=add_prefix("attn", prefix),
-        )
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
-        if hidden_states.shape[0] == 0:
-            assert (
-                not self.o_proj.reduce_results
-            ), "short-circuiting allreduce will lead to hangs"
-            return hidden_states
-
-        if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
-            q = self.q_a_layernorm(q)
-            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
-        else:
-            q = self.q_proj(hidden_states)[0].view(
-                -1, self.num_local_heads, self.qk_head_dim
-            )
-        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
-        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a.contiguous())
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        k_pe = latent_cache[:, :, self.kv_lora_rank :]
-        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)
-        k[..., : self.qk_nope_head_dim] = k_nope
-        k[..., self.qk_nope_head_dim :] = k_pe
-        q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view(
-            -1, self.num_local_heads * 256
-        )
-        k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value=0).view(
-            -1, self.num_local_heads * 256
-        )
-        v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(
-            -1, self.num_local_heads * 256
-        )
-        attn_output = self.attn(q, k, v, forward_batch)
-        attn_output = attn_output.view(-1, self.num_local_heads, 256)[
-            ..., : self.v_head_dim
-        ].reshape(-1, self.num_local_heads * self.v_head_dim)
-        output, _ = self.o_proj(attn_output)
-        return output
-
-
 class DeepseekV2AttentionMLA(nn.Module):
 
     def __init__(
@@ -705,6 +535,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_vc = None
         self.w_scale = None
 
+        self.w_scale_k = None
+        self.w_scale_v = None
+        self.use_deep_gemm_bmm = False
+
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
@@ -762,6 +596,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:
             assert (
@@ -787,9 +622,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                     positions, hidden_states, forward_batch
                 )
             else:
-                return self.forward_absorb(
+                return self.forward_absorb(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
         else:
-            return self.forward_absorb(
+            return self.forward_absorb(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
 
     def forward_normal(
         self,
@@ -838,6 +677,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -853,7 +693,24 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
-        if self.
+        if self.use_deep_gemm_bmm:
+            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+                per_tensor_quant_mla_deep_gemm_masked_fp8(
+                    q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+                )
+            )
+            q_nope_out = q_nope.new_empty(
+                (self.num_local_heads, aligned_m, self.kv_lora_rank)
+            )
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                (q_nope_val, q_nope_scale),
+                (self.w_kc, self.w_scale_k),
+                q_nope_out,
+                masked_m,
+                expected_m,
+            )
+            q_nope_out = q_nope_out[:, :expected_m, :]
+        elif self.w_kc.dtype == torch.float8_e4m3fnuz:
             # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
             q_nope_out = torch.bmm(
                 q_nope.to(torch.bfloat16).transpose(0, 1),
@@ -861,7 +718,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                q_nope.transpose(0, 1),
+                q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -884,7 +742,24 @@ class DeepseekV2AttentionMLA(nn.Module):
         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
-        if self.
+        if self.use_deep_gemm_bmm:
+            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
+                per_tensor_quant_mla_deep_gemm_masked_fp8(
+                    attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                )
+            )
+            attn_bmm_output = attn_output.new_empty(
+                (self.num_local_heads, aligned_m, self.v_head_dim)
+            )
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                (attn_output_val, attn_output_scale),
+                (self.w_vc, self.w_scale_v),
+                attn_bmm_output,
+                masked_m,
+                expected_m,
+            )
+            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
+        elif self.w_vc.dtype == torch.float8_e4m3fnuz:
             # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
@@ -892,7 +767,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                attn_output.transpose(0, 1),
+                attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -913,6 +789,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         enable_rope_fusion = (
             os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
@@ -939,7 +816,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                q_nope.transpose(0, 1),
+                q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -1035,7 +914,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                attn_output.transpose(0, 1),
+                attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -1173,6 +1054,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
 
+class _FFNInputMode(Enum):
+    # The MLP sublayer requires 1/tp_size tokens as input
+    SCATTERED = auto()
+    # The MLP sublayer requires all tokens as input
+    FULL = auto()
+
+
+@dataclass
+class _DecoderLayerInfo:
+    is_sparse: bool
+    ffn_input_mode: _FFNInputMode
+
+
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
@@ -1183,14 +1077,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         is_nextn: bool = False,
         prefix: str = "",
     ) -> None:
-
-        def is_sparse_layer(l: int):
-            return (
-                config.n_routed_experts is not None
-                and l >= config.first_k_dense_replace
-                and l % config.moe_layer_freq == 0
-            )
-
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -1201,68 +1087,54 @@ class DeepseekV2DecoderLayer(nn.Module):
         self.dp_size = get_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
+        self.self_attn = DeepseekV2AttentionMLA(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            qk_nope_head_dim=config.qk_nope_head_dim,
+            qk_rope_head_dim=config.qk_rope_head_dim,
+            v_head_dim=config.v_head_dim,
+            q_lora_rank=(
+                config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+            ),
+            kv_lora_rank=config.kv_lora_rank,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            layer_id=layer_id,
+            reduce_results=False,
+            prefix=add_prefix("self_attn", prefix),
+        )
 
-
-
-
-
-                num_heads=config.num_attention_heads,
-                qk_nope_head_dim=config.qk_nope_head_dim,
-                qk_rope_head_dim=config.qk_rope_head_dim,
-                v_head_dim=config.v_head_dim,
-                q_lora_rank=(
-                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
-                ),
-                kv_lora_rank=config.kv_lora_rank,
-                rope_theta=rope_theta,
-                rope_scaling=rope_scaling,
-                max_position_embeddings=max_position_embeddings,
-                quant_config=quant_config,
-                layer_id=layer_id,
-                reduce_results=False,
-                prefix=add_prefix("self_attn", prefix),
-            )
-        else:
-            self.self_attn = DeepseekV2Attention(
-                config=config,
-                hidden_size=self.hidden_size,
-                num_heads=config.num_attention_heads,
-                qk_nope_head_dim=config.qk_nope_head_dim,
-                qk_rope_head_dim=config.qk_rope_head_dim,
-                v_head_dim=config.v_head_dim,
-                q_lora_rank=(
-                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
-                ),
-                kv_lora_rank=config.kv_lora_rank,
-                rope_theta=rope_theta,
-                rope_scaling=rope_scaling,
-                max_position_embeddings=max_position_embeddings,
-                quant_config=quant_config,
-                layer_id=layer_id,
-                reduce_results=False,
-                prefix=add_prefix("self_attn", prefix),
-            )
+        self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
+        previous_layer_info = self._compute_info(
+            config, layer_id=layer_id - 1, is_nextn=False
+        )
 
-        if
+        if self.info.is_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
-            self.is_sparse = True
         else:
+            if self._enable_moe_dense_fully_dp():
+                mlp_tp_rank, mlp_tp_size = 0, 1
+            else:
+                mlp_tp_rank, mlp_tp_size = None, None
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
+                tp_rank=mlp_tp_rank,
+                tp_size=mlp_tp_size,
             )
-            self.is_sparse = False
 
         self.input_is_scattered = (
-
-            and global_server_args_dict["enable_deepep_moe"]
+            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
 
@@ -1271,28 +1143,51 @@ class DeepseekV2DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+    @staticmethod
+    def _enable_moe_dense_fully_dp():
+        return global_server_args_dict["moe_dense_tp_size"] == 1
+
+    @staticmethod
+    def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
+        is_sparse = is_nextn or (
+            config.n_routed_experts is not None
+            and layer_id >= config.first_k_dense_replace
+            and layer_id % config.moe_layer_freq == 0
+        )
+        ffn_input_mode = (
+            _FFNInputMode.SCATTERED
+            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+            or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
+            else _FFNInputMode.FULL
+        )
+        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
-        if
-            return self.
-                positions, hidden_states, forward_batch, residual
+        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
+            return self.forward_ffn_with_scattered_input(
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
-
-        return self.
-            positions, hidden_states, forward_batch, residual
+        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
+            return self.forward_ffn_with_full_input(
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
+        else:
+            raise NotImplementedError
 
-    def
+    def forward_ffn_with_full_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
 
         if hidden_states.shape[0] == 0:
@@ -1313,6 +1208,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
             forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
         )
 
         # Gather
@@ -1354,12 +1250,13 @@ class DeepseekV2DecoderLayer(nn.Module):
 
         return hidden_states, residual
 
-    def
+    def forward_ffn_with_scattered_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
 
         if hidden_states.shape[0] == 0:
@@ -1385,6 +1282,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
             forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
         )
 
         if self.attn_tp_size != 1:
@@ -1410,7 +1308,13 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual
             )
-
+
+        if not (
+            self._enable_moe_dense_fully_dp()
+            and (not self.info.is_sparse)
+            and hidden_states.shape[0] == 0
+        ):
+            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
 
         if self.is_last_layer and self.attn_tp_size != 1:
             hidden_states += residual
@@ -1466,6 +1370,14 @@ class DeepseekV2Model(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            # TODO for two-batch-overlap, we need a larger buffer size
+            buffer_size=len(self.layers) * 2,
+            dtype=torch.float32,
+            device=(
+                input_embeds.device if input_embeds is not None else input_ids.device
+            ),
+        )
 
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
@@ -1477,7 +1389,7 @@ class DeepseekV2Model(nn.Module):
             expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
             if not forward_batch.forward_mode.is_idle():
                 if residual is None:
@@ -1500,24 +1412,33 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.n_share_experts_fusion > 0:
+            # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+            if (
+                self.config.architectures[0] != "DeepseekV3ForCausalLM"
+                or self.config.n_routed_experts != 256
+            ):
+                self.n_share_experts_fusion = 0
+                global_server_args_dict["n_share_experts_fusion"] = 0
+                logger.info(
+                    "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+                )
+            else:
+                assert (
+                    self.n_share_experts_fusion == self.tp_size
+                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performace."
+        elif self.n_share_experts_fusion == 0:
+            if (
+                torch.cuda.get_device_capability("cuda") >= (9, 0)
+                and self.config.architectures[0] == "DeepseekV3ForCausalLM"
+                and self.config.n_routed_experts == 256
+                and (not global_server_args_dict["enable_deepep_moe"])
+            ):
+                self.n_share_experts_fusion = self.tp_size
+                global_server_args_dict["n_share_experts_fusion"] = self.tp_size
+                logger.info(
+                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
+                )
 
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
@@ -1552,78 +1473,92 @@ class DeepseekV2ForCausalLM(nn.Module):
     def post_load_weights(self):
 
         # Perform post-processing after loading weights
-
-
-
-
-            if
-
-
-
-
-
-                        self_attn.kv_b_proj.qzeros,
-                    ).T
-                else:
-                    w = ops.awq_dequantize(
-                        self_attn.kv_b_proj.qweight,
-                        self_attn.kv_b_proj.scales,
-                        self_attn.kv_b_proj.qzeros,
-                        0,
-                        0,
-                        0,
-                    ).T
+        for layer_id in range(self.config.num_hidden_layers):
+            self_attn = self.model.layers[layer_id].self_attn
+            if hasattr(self_attn.kv_b_proj, "qweight"):
+                # AWQ compatible
+                if _is_cuda:
+                    w = awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                    ).T
                 else:
-                    w =
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    w = awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
+            else:
+                w = self_attn.kv_b_proj.weight
+            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+            # This may affect the accuracy of fp8 model.
+            # Fix deepseek v3 blockwise bmm by using deep_gemm
+            use_deep_gemm_bmm = False
+            model_dtype = torch.get_default_dtype()
+
+            if w.dtype in (
+                torch.float8_e4m3fn,
+                torch.float8_e4m3fnuz,
+            ):
+                if hasattr(self.quant_config, "weight_block_size"):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        if _is_hip:
+                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                weight=w,
+                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                input_scale=None,
+                            )
+                        else:
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
 
+                        if (
+                            _is_cuda
+                            and _enable_jit_deepgemm_bmm
+                            and weight_block_size[0] == 128
+                            and weight_block_size[1] == 128
+                            and model_dtype == torch.bfloat16
+                        ):
+                            block_scale = weight_scale
+                            use_deep_gemm_bmm = True
+                        else:
                             w, scale = block_quant_to_tensor_quant(
                                 weight, weight_scale, weight_block_size
                             )
                             self_attn.w_scale = scale
-
+                else:
+                    weight = w
+                    weight_scale = self_attn.kv_b_proj.weight_scale
+                    w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
+                    self_attn.w_scale = scale
+
+            if w.dtype == torch.int8:
+                if hasattr(self.quant_config, "weight_block_size"):
+                    # block-wise int8 need it
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                         weight = w
-                        weight_scale = self_attn.kv_b_proj.
-                        w
-
-
-
-
-
-
-
-
-
-
-
-
-                        ).to(torch.bfloat16)
-                else:
-                    # channel-wise int8 need it
-                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                        torch.bfloat16
-                    )
-            w_kc, w_vc = w.unflatten(
-                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                        w = int8_block_dequant(
+                            weight, weight_scale, weight_block_size
+                        ).to(torch.bfloat16)
+                else:
+                    # channel-wise int8 need it
+                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                        torch.bfloat16
+                    )
+
+            w_kc, w_vc = w.unflatten(
+                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            if not use_deep_gemm_bmm:
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                 self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                 if (
@@ -1633,6 +1568,17 @@ class DeepseekV2ForCausalLM(nn.Module):
                     self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                     if _is_hip:
                         self_attn.w_scale *= 2.0
+            else:
+                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
+                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
+                ws_kc, ws_vc = block_scale.unflatten(
+                    0, (-1, (num_tiles_k + num_tiles_n))
+                ).split([num_tiles_k, num_tiles_n], dim=1)
+                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
+                self_attn.w_scale_v = ws_vc.contiguous()
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
+                self_attn.w_vc = w_vc.contiguous()
+                self_attn.use_deep_gemm_bmm = True
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -1640,7 +1586,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
         ]
-        if self.n_share_experts_fusion
+        if self.n_share_experts_fusion > 0:
             weights_list = list(weights)
             weights_dict = dict(weights_list)
             if self.quant_config.get_name() == "w8a8_int8":
@@ -1699,12 +1645,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts
-            + (
-                self.n_share_experts_fusion
-                if self.n_share_experts_fusion is not None
-                else 0
-            ),
+            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
         )
 
         params_dict = dict(self.named_parameters())
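Note: a recurring change in this diff is the new `zero_allocator` argument threaded from `DeepseekV2Model.forward` (which builds a `BumpAllocator` with `buffer_size=len(self.layers) * 2` float32 slots) through `DeepseekV2DecoderLayer.forward` into the MLA attention code, where `zero_allocator.allocate(1)` supplies the scale tensor passed to `per_tensor_quant_mla_fp8`. The sketch below is a minimal, hypothetical stand-in for such an allocator, written only to illustrate the slice-from-one-preallocated-buffer pattern; the real `sglang.srt.utils.BumpAllocator` may differ in details.

```python
import torch


class BumpAllocator:
    """Illustrative stand-in (assumption, not the sglang implementation):
    hands out slices of one preallocated zero-filled tensor instead of
    calling torch.zeros for every scale tensor."""

    def __init__(self, buffer_size: int, dtype: torch.dtype, device: str):
        self._buffer = torch.zeros(buffer_size, dtype=dtype, device=device)
        self._offset = 0

    def allocate(self, size: int) -> torch.Tensor:
        # Bump the offset and return a view into the shared buffer.
        assert self._offset + size <= self._buffer.numel(), "allocator exhausted"
        view = self._buffer[self._offset : self._offset + size]
        self._offset += size
        return view


# Usage mirroring the diff: one allocator per model forward pass,
# sized at two float32 slots per decoder layer.
num_layers = 4
zero_allocator = BumpAllocator(
    buffer_size=num_layers * 2, dtype=torch.float32, device="cpu"
)
for _ in range(num_layers):
    q_scale = zero_allocator.allocate(1)   # e.g. scale slot for the q_nope FP8 quant
    kv_scale = zero_allocator.allocate(1)  # e.g. scale slot for the attn_output FP8 quant
    assert q_scale.shape == (1,) and kv_scale.shape == (1,)
```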