sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -18,6 +18,8 @@
 
 import logging
 import os
+from dataclasses import dataclass
+from enum import Enum, IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -27,6 +29,7 @@ from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
@@ -54,9 +57,14 @@ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_kernel import (
+    _enable_jit_deepgemm_bmm,
+    per_tensor_quant_mla_deep_gemm_masked_fp8,
+    per_tensor_quant_mla_fp8,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
-
+    channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.int8_utils import (
@@ -72,15 +80,16 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 
 if _is_cuda:
-    from
+    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
+    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
 else:
-    from vllm import
+    from vllm._custom_ops import awq_dequantize
 
 if _is_hip:
     from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
@@ -92,6 +101,18 @@ expert_distribution_recorder = ExpertDistributionRecorder()
 logger = logging.getLogger(__name__)
 
 
+class AttnForwardMethod(IntEnum):
+    # Use multi-head attention
+    MHA = auto()
+
+    # Use absorbed multi-latent attention
+    MLA = auto()
+
+    # Use multi-head attention, but with KV cache chunked.
+    # This method can avoid OOM when prefix lengths are long.
+    MHA_CHUNKED_KV = auto()
+
+
 class DeepseekV2MLP(nn.Module):
     def __init__(
         self,
@@ -131,7 +152,7 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(self, x, forward_mode: Optional[ForwardMode] = None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -172,13 +193,8 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.n_share_experts_fusion = (
-            global_server_args_dict["n_share_experts_fusion"]
-            if global_server_args_dict["n_share_experts_fusion"] is not None
-            else 0
-        )
+        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
 
-        self.routed_scaling_factor = config.routed_scaling_factor
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -210,6 +226,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
             **(
                 dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
@@ -278,10 +295,7 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_deepep(hidden_states, forward_mode)
 
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-
-            shared_output = self.shared_experts(hidden_states)
-        else:
-            shared_output = None
+        shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         final_hidden_states = (
@@ -311,8 +325,7 @@ class DeepseekV2MoE(nn.Module):
     ):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-
-            shared_output = self.shared_experts(hidden_states)
+        shared_output = self._forward_shared_experts(hidden_states)
         topk_weights, topk_idx = select_experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
@@ -322,8 +335,10 @@ class DeepseekV2MoE(nn.Module):
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
         if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
            (
                 hidden_states,
                 topk_idx,
@@ -336,19 +351,15 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states,
                 topk_idx,
                 topk_weights,
-                self.num_experts,
                 forward_mode=forward_mode,
             )
-        final_hidden_states = (
-
-
-
-
-
-
-                forward_mode=forward_mode,
-            )
-            * self.routed_scaling_factor
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            reorder_topk_ids=reorder_topk_ids,
+            seg_indptr=seg_indptr,
+            masked_m=masked_m,
+            expected_m=expected_m,
+            forward_mode=forward_mode,
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
@@ -357,11 +368,19 @@ class DeepseekV2MoE(nn.Module):
                 topk_weights,
                 forward_mode,
             )
+        final_hidden_states *= self.routed_scaling_factor
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
 
         return final_hidden_states
 
+    def _forward_shared_experts(self, hidden_states):
+        if self.n_share_experts_fusion == 0:
+            return self.shared_experts(hidden_states)
+        else:
+            return None
+
 
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     import math
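A note on the MoE change above: the routed scaling factor is now applied once, in place, after the dispatcher combine, and the separate shared-expert pass is skipped whenever shared-expert fusion is active (the fused copies live inside the routed experts). A minimal sketch of the resulting combine step, using illustrative names (`routed_out`, `shared_out`) rather than the module's actual attributes:

import torch

def combine_moe_output(
    routed_out: torch.Tensor,          # routed-expert output after dispatch/combine
    shared_out: torch.Tensor | None,   # None when shared experts are fused into the routed path
    routed_scaling_factor: float,
) -> torch.Tensor:
    # Scale the routed contribution once, after the all-to-all combine.
    routed_out = routed_out * routed_scaling_factor
    # Add the shared-expert branch only when it was computed separately.
    if shared_out is not None:
        routed_out = routed_out + shared_out
    return routed_out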
@@ -371,178 +390,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
-class DeepseekV2Attention(nn.Module):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        hidden_size: int,
-        num_heads: int,
-        qk_nope_head_dim: int,
-        qk_rope_head_dim: int,
-        v_head_dim: int,
-        q_lora_rank: int,
-        kv_lora_rank: int,
-        rope_theta: float = 10000,
-        rope_scaling: Optional[Dict[str, Any]] = None,
-        max_position_embeddings: int = 8192,
-        quant_config: Optional[QuantizationConfig] = None,
-        layer_id=None,
-        reduce_results: bool = True,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.layer_id = layer_id
-        self.hidden_size = hidden_size
-        self.qk_nope_head_dim = qk_nope_head_dim
-        self.qk_rope_head_dim = qk_rope_head_dim
-        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
-        self.v_head_dim = v_head_dim
-        self.q_lora_rank = q_lora_rank
-        self.kv_lora_rank = kv_lora_rank
-
-        self.dp_size = get_attention_dp_size()
-        attn_tp_rank = get_attention_tp_rank()
-        attn_tp_size = get_attention_tp_size()
-
-        self.num_heads = num_heads
-        assert num_heads % attn_tp_size == 0
-        self.num_local_heads = num_heads // attn_tp_size
-        self.scaling = self.qk_head_dim**-0.5
-        self.rope_theta = rope_theta
-        self.max_position_embeddings = max_position_embeddings
-
-        if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
-                self.hidden_size,
-                self.q_lora_rank,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_a_proj", prefix),
-            )
-            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-            self.q_b_proj = ColumnParallelLinear(
-                q_lora_rank,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_b_proj", prefix),
-            )
-        else:
-            self.q_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_proj", prefix),
-                tp_rank=attn_tp_rank,
-                tp_size=attn_tp_size,
-            )
-
-        self.kv_a_proj_with_mqa = ReplicatedLinear(
-            self.hidden_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
-        )
-        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        self.kv_b_proj = ColumnParallelLinear(
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_b_proj", prefix),
-        )
-        # O projection.
-        self.o_proj = RowParallelLinear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("o_proj", prefix),
-            reduce_results=reduce_results,
-            tp_rank=attn_tp_rank,
-            tp_size=attn_tp_size,
-        )
-        rope_scaling["rope_type"] = "deepseek_yarn"
-        self.rotary_emb = get_rope_wrapper(
-            qk_rope_head_dim,
-            rotary_dim=qk_rope_head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-            is_neox_style=False,
-            device=global_server_args_dict["device"],
-        )
-
-        if rope_scaling:
-            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
-            scaling_factor = rope_scaling["factor"]
-            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
-            self.scaling = self.scaling * mscale * mscale
-
-        # TODO, support head_size 192
-        self.attn = RadixAttention(
-            self.num_local_heads,
-            256,
-            self.scaling,
-            num_kv_heads=self.num_local_heads,
-            layer_id=layer_id,
-            prefix=add_prefix("attn", prefix),
-        )
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
-        if hidden_states.shape[0] == 0:
-            assert (
-                not self.o_proj.reduce_results
-            ), "short-circuiting allreduce will lead to hangs"
-            return hidden_states
-
-        if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
-            q = self.q_a_layernorm(q)
-            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
-        else:
-            q = self.q_proj(hidden_states)[0].view(
-                -1, self.num_local_heads, self.qk_head_dim
-            )
-        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
-        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a.contiguous())
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        k_pe = latent_cache[:, :, self.kv_lora_rank :]
-        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)
-        k[..., : self.qk_nope_head_dim] = k_nope
-        k[..., self.qk_nope_head_dim :] = k_pe
-        q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view(
-            -1, self.num_local_heads * 256
-        )
-        k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value=0).view(
-            -1, self.num_local_heads * 256
-        )
-        v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(
-            -1, self.num_local_heads * 256
-        )
-        attn_output = self.attn(q, k, v, forward_batch)
-        attn_output = attn_output.view(-1, self.num_local_heads, 256)[
-            ..., : self.v_head_dim
-        ].reshape(-1, self.num_local_heads * self.v_head_dim)
-        output, _ = self.o_proj(attn_output)
-        return output
-
-
 class DeepseekV2AttentionMLA(nn.Module):
 
     def __init__(
@@ -669,6 +516,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            quant_config=quant_config,
             prefix=add_prefix("attn_mqa", prefix),
         )
 
@@ -679,6 +527,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
             v_head_dim=self.v_head_dim,
+            quant_config=quant_config,
             prefix=add_prefix("attn_mha", prefix),
         )
 
@@ -686,39 +535,68 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_vc = None
         self.w_scale = None
 
+        self.w_scale_k = None
+        self.w_scale_v = None
+        self.use_deep_gemm_bmm = False
+
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
+        self.disable_chunked_prefix_cache = global_server_args_dict[
+            "disable_chunked_prefix_cache"
+        ]
         self.attention_backend = global_server_args_dict["attention_backend"]
         self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
 
-
+        # TODO: Design a finer way to determine the threshold
+        self.chunked_prefix_cache_threshold = 8192
+
+    def dispatch_attn_forward_method(
+        self, forward_batch: ForwardBatch
+    ) -> AttnForwardMethod:
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
-
+            if (
                 not self.flashinfer_mla_disable_ragged
                 and forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
                 and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            )
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
         elif self.attention_backend == "fa3":
-            # Flash Attention:
-
+            # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not self.disable_chunked_prefix_cache
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and sum(forward_batch.extend_prefix_lens_cpu)
+                >= self.chunked_prefix_cache_threshold
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return AttnForwardMethod.MLA
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-
+            if (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
                 and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            )
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
 
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:
             assert (
@@ -726,8 +604,14 @@ class DeepseekV2AttentionMLA(nn.Module):
             ), "short-circuiting allreduce will lead to hangs"
             return hidden_states
 
-
+        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
+
+        if attn_forward_method == AttnForwardMethod.MHA:
             return self.forward_normal(positions, hidden_states, forward_batch)
+        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
+            return self.forward_normal_chunked_kv(
+                positions, hidden_states, forward_batch
+            )
         else:
             if _is_hip:
                 if (
@@ -738,9 +622,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                         positions, hidden_states, forward_batch
                     )
                 else:
-                    return self.forward_absorb(
+                    return self.forward_absorb(
+                        positions, hidden_states, forward_batch, zero_allocator
+                    )
             else:
-                return self.forward_absorb(
+                return self.forward_absorb(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
 
     def forward_normal(
         self,
@@ -789,6 +677,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -804,15 +693,33 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
-        if self.
+        if self.use_deep_gemm_bmm:
+            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+                per_tensor_quant_mla_deep_gemm_masked_fp8(
+                    q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+                )
+            )
+            q_nope_out = q_nope.new_empty(
+                (self.num_local_heads, aligned_m, self.kv_lora_rank)
+            )
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                (q_nope_val, q_nope_scale),
+                (self.w_kc, self.w_scale_k),
+                q_nope_out,
+                masked_m,
+                expected_m,
+            )
+            q_nope_out = q_nope_out[:, :expected_m, :]
+        elif self.w_kc.dtype == torch.float8_e4m3fnuz:
             # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
             q_nope_out = torch.bmm(
                 q_nope.to(torch.bfloat16).transpose(0, 1),
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            q_nope_val, q_nope_scale =
-                q_nope.transpose(0, 1),
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -835,15 +742,33 @@ class DeepseekV2AttentionMLA(nn.Module):
         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
-        if self.
+        if self.use_deep_gemm_bmm:
+            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
+                per_tensor_quant_mla_deep_gemm_masked_fp8(
+                    attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                )
+            )
+            attn_bmm_output = attn_output.new_empty(
+                (self.num_local_heads, aligned_m, self.v_head_dim)
+            )
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                (attn_output_val, attn_output_scale),
+                (self.w_vc, self.w_scale_v),
+                attn_bmm_output,
+                masked_m,
+                expected_m,
+            )
+            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
+        elif self.w_vc.dtype == torch.float8_e4m3fnuz:
             # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            attn_output_val, attn_output_scale =
-                attn_output.transpose(0, 1),
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
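For context on the per-tensor FP8 path above: `bmm_fp8` takes an FP8 tensor plus a single per-tensor scale, so the activation is quantized with one scale for the whole tensor before the batched matmul, and the call sites show that `per_tensor_quant_mla_fp8` additionally receives a small buffer from the bump allocator. The helper below is only an illustrative sketch of the per-tensor quantization idea, not the real kernel from `sglang.srt.layers.quantization.fp8_kernel`:

import torch

def per_tensor_fp8_quant_sketch(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # One scale for the whole tensor: map the absolute max onto the FP8 range.
    finfo = torch.finfo(dtype)
    amax = x.abs().amax().clamp(min=1e-12).to(torch.float32)
    scale = amax / finfo.max
    # Quantize, and return the FP8 values together with the dequantization scale.
    x_q = (x.to(torch.float32) / scale).clamp(finfo.min, finfo.max).to(dtype)
    return x_q, scale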
@@ -864,6 +789,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         enable_rope_fusion = (
             os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
@@ -889,8 +815,10 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            q_nope_val, q_nope_scale =
-                q_nope.transpose(0, 1),
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -985,8 +913,10 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            attn_output_val, attn_output_scale =
-                attn_output.transpose(0, 1),
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -1002,6 +932,140 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return output
 
+    def _chunked_prefix_attn_mha(
+        self,
+        q: torch.Tensor,
+        accum_output: torch.Tensor,
+        accum_lse: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+
+        assert forward_batch.num_prefix_chunks is not None
+        for i in range(forward_batch.num_prefix_chunks):
+            forward_batch.set_prefix_chunk_idx(i)
+
+            # Fetch latent cache from memory pool with precomputed chunked kv indices
+            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+                self.attn_mha.layer_id
+            )
+            latent_cache = latent_cache_buf[
+                forward_batch.prefix_chunk_kv_indices[i]
+            ].contiguous()
+
+            kv_a_normed, k_pe = latent_cache.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            kv_a_normed = kv_a_normed.squeeze(1).contiguous()
+            kv = self.kv_b_proj(kv_a_normed)[0]
+            kv = kv.view(
+                -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
+            )
+            v = kv[..., self.qk_nope_head_dim :]
+            k_nope = kv[..., : self.qk_nope_head_dim]
+
+            k = torch.empty(
+                (
+                    k_nope.shape[0],
+                    self.num_local_heads,
+                    self.qk_nope_head_dim + self.qk_rope_head_dim,
+                ),
+                dtype=v.dtype,
+                device=v.device,
+            )
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
+
+            output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+            lse = torch.transpose(lse, 0, 1).contiguous()
+            tmp_output = torch.empty_like(accum_output)
+            tmp_lse = torch.empty_like(accum_lse)
+            merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
+            accum_output, accum_lse = tmp_output, tmp_lse
+
+        return accum_output
+
+    def forward_normal_chunked_kv(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        # In normal mha, the k and v tensors will become overly large when the prefix length is long.
+        # To avoid this, we split the kv cache into chunks and process them one after another.
+        # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
+        # The top comments in https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
+        # will be helpful for understanding the purpose of this function.
+
+        # First do normal mha forward to get output for extended part
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        latent_cache = latent_cache.unsqueeze(1)
+        kv_a = self.kv_a_layernorm(kv_a.contiguous())
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]
+        k_pe = latent_cache[:, :, self.kv_lora_rank :]
+
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q[..., self.qk_nope_head_dim :] = q_pe
+        k = torch.empty_like(q)
+        k[..., : self.qk_nope_head_dim] = k_nope
+        k[..., self.qk_nope_head_dim :] = k_pe
+
+        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+        # Save latent cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+        )
+
+        # Do mha for extended part without prefix
+        forward_batch.set_attn_attend_prefix_cache(False)
+        attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+        lse = torch.transpose(lse, 0, 1).contiguous()
+
+        # Do mha attention with chunked prefix cache if there are any sequence with prefix
+        if any(forward_batch.extend_prefix_lens_cpu):
+            # Only initialize the info once
+            if forward_batch.num_prefix_chunks is None:
+                forward_batch.prepare_chunked_prefix_cache_info(q.device)
+
+            forward_batch.set_attn_attend_prefix_cache(True)
+            attn_output = self._chunked_prefix_attn_mha(
+                q=q,
+                accum_output=attn_output,
+                accum_lse=lse,
+                forward_batch=forward_batch,
+            )
+
+        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class _FFNInputMode(Enum):
+    # The MLP sublayer requires 1/tp_size tokens as input
+    SCATTERED = auto()
+    # The MLP sublayer requires all tokens as input
+    FULL = auto()
+
+
+@dataclass
+class _DecoderLayerInfo:
+    is_sparse: bool
+    ffn_input_mode: _FFNInputMode
+
 
 class DeepseekV2DecoderLayer(nn.Module):
 
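The chunked-prefix path above hinges on merging partial attention results: attention over the new tokens and attention over each KV-cache chunk are computed separately and then combined using their log-sum-exp statistics, which is conceptually what the `merge_state_v2` call writes into the tmp buffers. A plain PyTorch sketch of the underlying math, with assumed shapes [tokens, heads, head_dim] for the outputs and [tokens, heads] for the LSE values:

import torch

def merge_attention_states(o_a, lse_a, o_b, lse_b):
    # Softmax over the union of two key sets is a convex combination of the
    # partial outputs, weighted by exp(lse_partial - lse_merged).
    lse_merged = torch.logaddexp(lse_a, lse_b)            # [tokens, heads]
    w_a = torch.exp(lse_a - lse_merged).unsqueeze(-1)     # [tokens, heads, 1]
    w_b = torch.exp(lse_b - lse_merged).unsqueeze(-1)
    return w_a * o_a + w_b * o_b, lse_merged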
@@ -1013,14 +1077,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         is_nextn: bool = False,
         prefix: str = "",
     ) -> None:
-
-        def is_sparse_layer(l: int):
-            return (
-                config.n_routed_experts is not None
-                and l >= config.first_k_dense_replace
-                and l % config.moe_layer_freq == 0
-            )
-
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -1031,68 +1087,54 @@ class DeepseekV2DecoderLayer(nn.Module):
         self.dp_size = get_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
+        self.self_attn = DeepseekV2AttentionMLA(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            qk_nope_head_dim=config.qk_nope_head_dim,
+            qk_rope_head_dim=config.qk_rope_head_dim,
+            v_head_dim=config.v_head_dim,
+            q_lora_rank=(
+                config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+            ),
+            kv_lora_rank=config.kv_lora_rank,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            layer_id=layer_id,
+            reduce_results=False,
+            prefix=add_prefix("self_attn", prefix),
+        )
 
-
-
-
-
-                num_heads=config.num_attention_heads,
-                qk_nope_head_dim=config.qk_nope_head_dim,
-                qk_rope_head_dim=config.qk_rope_head_dim,
-                v_head_dim=config.v_head_dim,
-                q_lora_rank=(
-                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
-                ),
-                kv_lora_rank=config.kv_lora_rank,
-                rope_theta=rope_theta,
-                rope_scaling=rope_scaling,
-                max_position_embeddings=max_position_embeddings,
-                quant_config=quant_config,
-                layer_id=layer_id,
-                reduce_results=False,
-                prefix=add_prefix("self_attn", prefix),
-            )
-        else:
-            self.self_attn = DeepseekV2Attention(
-                config=config,
-                hidden_size=self.hidden_size,
-                num_heads=config.num_attention_heads,
-                qk_nope_head_dim=config.qk_nope_head_dim,
-                qk_rope_head_dim=config.qk_rope_head_dim,
-                v_head_dim=config.v_head_dim,
-                q_lora_rank=(
-                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
-                ),
-                kv_lora_rank=config.kv_lora_rank,
-                rope_theta=rope_theta,
-                rope_scaling=rope_scaling,
-                max_position_embeddings=max_position_embeddings,
-                quant_config=quant_config,
-                layer_id=layer_id,
-                reduce_results=False,
-                prefix=add_prefix("self_attn", prefix),
-            )
+        self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
+        previous_layer_info = self._compute_info(
+            config, layer_id=layer_id - 1, is_nextn=False
+        )
 
-        if
+        if self.info.is_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
-            self.is_sparse = True
         else:
+            if self._enable_moe_dense_fully_dp():
+                mlp_tp_rank, mlp_tp_size = 0, 1
+            else:
+                mlp_tp_rank, mlp_tp_size = None, None
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
+                tp_rank=mlp_tp_rank,
+                tp_size=mlp_tp_size,
             )
-            self.is_sparse = False
 
         self.input_is_scattered = (
-
-            and global_server_args_dict["enable_deepep_moe"]
+            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
 
@@ -1101,28 +1143,51 @@ class DeepseekV2DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+    @staticmethod
+    def _enable_moe_dense_fully_dp():
+        return global_server_args_dict["moe_dense_tp_size"] == 1
+
+    @staticmethod
+    def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
+        is_sparse = is_nextn or (
+            config.n_routed_experts is not None
+            and layer_id >= config.first_k_dense_replace
+            and layer_id % config.moe_layer_freq == 0
+        )
+        ffn_input_mode = (
+            _FFNInputMode.SCATTERED
+            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+            or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
+            else _FFNInputMode.FULL
+        )
+        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
-        if
-            return self.
-                positions, hidden_states, forward_batch, residual
+        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
+            return self.forward_ffn_with_scattered_input(
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
-
-        return self.
-            positions, hidden_states, forward_batch, residual
+        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
+            return self.forward_ffn_with_full_input(
+                positions, hidden_states, forward_batch, residual, zero_allocator
            )
+        else:
+            raise NotImplementedError
 
-    def
+    def forward_ffn_with_full_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
 
         if hidden_states.shape[0] == 0:
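As a concrete reading of `_compute_info` above: a layer is sparse (MoE) when it is the next-N draft layer, or when routed experts are configured and `layer_id >= first_k_dense_replace` with `layer_id % moe_layer_freq == 0`. With illustrative values `first_k_dense_replace = 3` and `moe_layer_freq = 1` (examples only, not read from a shipped config), layers 0-2 get the dense MLP and every layer from 3 on gets DeepseekV2MoE:

def is_sparse_layer(layer_id: int, first_k_dense_replace: int = 3, moe_layer_freq: int = 1) -> bool:
    # Mirrors the non-nextn branch of DeepseekV2DecoderLayer._compute_info,
    # assuming config.n_routed_experts is set.
    return layer_id >= first_k_dense_replace and layer_id % moe_layer_freq == 0

assert [is_sparse_layer(i) for i in range(5)] == [False, False, False, True, True]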
@@ -1143,6 +1208,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
             forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
         )
 
         # Gather
@@ -1184,12 +1250,13 @@ class DeepseekV2DecoderLayer(nn.Module):
 
         return hidden_states, residual
 
-    def
+    def forward_ffn_with_scattered_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
 
         if hidden_states.shape[0] == 0:
@@ -1215,6 +1282,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
             forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
         )
 
         if self.attn_tp_size != 1:
@@ -1240,7 +1308,13 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual
             )
-
+
+        if not (
+            self._enable_moe_dense_fully_dp()
+            and (not self.info.is_sparse)
+            and hidden_states.shape[0] == 0
+        ):
+            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
 
         if self.is_last_layer and self.attn_tp_size != 1:
             hidden_states += residual
@@ -1296,6 +1370,14 @@ class DeepseekV2Model(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            # TODO for two-batch-overlap, we need a larger buffer size
+            buffer_size=len(self.layers) * 2,
+            dtype=torch.float32,
+            device=(
+                input_embeds.device if input_embeds is not None else input_ids.device
+            ),
+        )
 
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
@@ -1307,7 +1389,7 @@ class DeepseekV2Model(nn.Module):
             expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
             if not forward_batch.forward_mode.is_idle():
                 if residual is None:
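The `zero_allocator` threaded through the layers above is a `BumpAllocator` created once per forward pass and handed to every attention layer so each FP8 quantization call can take a small zero-initialized float32 slot without a fresh allocation. The class below is a sketch of that idea reconstructed only from the two call sites visible in this diff (`BumpAllocator(buffer_size=..., dtype=..., device=...)` and `allocate(1)`); the real implementation lives in `sglang/srt/utils.py` and may differ:

import torch

class BumpAllocatorSketch:
    """Illustrative bump allocator: one zeroed buffer, handed out slice by slice."""

    def __init__(self, buffer_size: int, dtype: torch.dtype, device: torch.device):
        self._buffer = torch.zeros(buffer_size, dtype=dtype, device=device)
        self._offset = 0

    def allocate(self, size: int) -> torch.Tensor:
        # Each call returns the next `size` elements of the preallocated buffer.
        assert self._offset + size <= self._buffer.numel(), "bump allocator exhausted"
        out = self._buffer[self._offset : self._offset + size]
        self._offset += size
        return out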
@@ -1330,24 +1412,33 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.n_share_experts_fusion > 0:
+            # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+            if (
+                self.config.architectures[0] != "DeepseekV3ForCausalLM"
+                or self.config.n_routed_experts != 256
+            ):
+                self.n_share_experts_fusion = 0
+                global_server_args_dict["n_share_experts_fusion"] = 0
+                logger.info(
+                    "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+                )
+            else:
+                assert (
+                    self.n_share_experts_fusion == self.tp_size
+                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performace."
+        elif self.n_share_experts_fusion == 0:
+            if (
+                torch.cuda.get_device_capability("cuda") >= (9, 0)
+                and self.config.architectures[0] == "DeepseekV3ForCausalLM"
+                and self.config.n_routed_experts == 256
+                and (not global_server_args_dict["enable_deepep_moe"])
+            ):
+                self.n_share_experts_fusion = self.tp_size
+                global_server_args_dict["n_share_experts_fusion"] = self.tp_size
+                logger.info(
+                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
+                )
 
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
@@ -1382,35 +1473,38 @@ class DeepseekV2ForCausalLM(nn.Module):
     def post_load_weights(self):
 
         # Perform post-processing after loading weights
-
-
-
-
-        if
-
-
-
-
-
-                self_attn.kv_b_proj.qzeros,
-            ).T
-        else:
-            w = ops.awq_dequantize(
-                self_attn.kv_b_proj.qweight,
-                self_attn.kv_b_proj.scales,
-                self_attn.kv_b_proj.qzeros,
-                0,
-                0,
-                0,
-            ).T
+        for layer_id in range(self.config.num_hidden_layers):
+            self_attn = self.model.layers[layer_id].self_attn
+            if hasattr(self_attn.kv_b_proj, "qweight"):
+                # AWQ compatible
+                if _is_cuda:
+                    w = awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                    ).T
                 else:
-            w =
-
-
-
-
-
+                    w = awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
+            else:
+                w = self_attn.kv_b_proj.weight
+            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+            # This may affect the accuracy of fp8 model.
+            # Fix deepseek v3 blockwise bmm by using deep_gemm
+            use_deep_gemm_bmm = False
+            model_dtype = torch.get_default_dtype()
+
+            if w.dtype in (
+                torch.float8_e4m3fn,
+                torch.float8_e4m3fnuz,
+            ):
+                if hasattr(self.quant_config, "weight_block_size"):
                     weight_block_size = self.quant_config.weight_block_size
                     if weight_block_size is not None:
                         assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
@@ -1424,29 +1518,47 @@ class DeepseekV2ForCausalLM(nn.Module):
                         weight = w
                         weight_scale = self_attn.kv_b_proj.weight_scale_inv
 
-
-
-
-
-
-
-
-
-
-
-
-                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                        w = int8_block_dequant(
+                        if (
+                            _is_cuda
+                            and _enable_jit_deepgemm_bmm
+                            and weight_block_size[0] == 128
+                            and weight_block_size[1] == 128
+                            and model_dtype == torch.bfloat16
+                        ):
+                            block_scale = weight_scale
+                            use_deep_gemm_bmm = True
+                        else:
+                            w, scale = block_quant_to_tensor_quant(
                                 weight, weight_scale, weight_block_size
-                        )
-
-
-
-
-
-
-
-
+                            )
+                            self_attn.w_scale = scale
+                else:
+                    weight = w
+                    weight_scale = self_attn.kv_b_proj.weight_scale
+                    w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
+                    self_attn.w_scale = scale
+
+            if w.dtype == torch.int8:
+                if hasattr(self.quant_config, "weight_block_size"):
+                    # block-wise int8 need it
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                        w = int8_block_dequant(
+                            weight, weight_scale, weight_block_size
+                        ).to(torch.bfloat16)
+                    else:
+                        # channel-wise int8 need it
+                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                            torch.bfloat16
+                        )
+
+            w_kc, w_vc = w.unflatten(
+                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            if not use_deep_gemm_bmm:
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                 self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                 if (
@@ -1456,6 +1568,17 @@ class DeepseekV2ForCausalLM(nn.Module):
                     self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                     if _is_hip:
                         self_attn.w_scale *= 2.0
+            else:
+                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
+                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
+                ws_kc, ws_vc = block_scale.unflatten(
+                    0, (-1, (num_tiles_k + num_tiles_n))
+                ).split([num_tiles_k, num_tiles_n], dim=1)
+                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
+                self_attn.w_scale_v = ws_vc.contiguous()
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
+                self_attn.w_vc = w_vc.contiguous()
+                self_attn.use_deep_gemm_bmm = True
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
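The weight post-processing above splits the dequantized `kv_b_proj` matrix into the two absorbed MLA operands: `w_kc` (applied to the no-PE query/key part) and `w_vc` (applied to the value part). A shape-only sketch of that unflatten/split, using illustrative DeepSeek-style dimensions (`kv_lora_rank=512`, `qk_nope_head_dim=128`, `v_head_dim=128`, 16 local heads); the numbers are examples, not read from any config:

import torch

num_local_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 16, 128, 128, 512

# Per the kv_b_proj definition, its weight is [num_heads * (qk_nope + v), kv_lora_rank].
w = torch.randn(num_local_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Regroup rows per head, then split each head's rows into the K and V halves.
w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)
assert w_kc.shape == (num_local_heads, qk_nope_head_dim, kv_lora_rank)
assert w_vc.shape == (num_local_heads, v_head_dim, kv_lora_rank)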
@@ -1463,17 +1586,27 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.n_share_experts_fusion
+        if self.n_share_experts_fusion > 0:
             weights_list = list(weights)
             weights_dict = dict(weights_list)
-
-
-
-
-
-
-
-
+            if self.quant_config.get_name() == "w8a8_int8":
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale",
+                    "up_proj.weight",
+                    "up_proj.weight_scale",
+                ]
+            else:
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale_inv",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale_inv",
+                    "up_proj.weight",
+                    "up_proj.weight_scale_inv",
+                ]
             names_to_remove = []
             for moe_layer in tqdm(
                 range(
@@ -1512,12 +1645,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                     ckpt_gate_proj_name="gate_proj",
                     ckpt_down_proj_name="down_proj",
                     ckpt_up_proj_name="up_proj",
-                    num_experts=self.config.n_routed_experts
-                    + (
-                        self.n_share_experts_fusion
-                        if self.n_share_experts_fusion is not None
-                        else 0
-                    ),
+                    num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
                 )
 
         params_dict = dict(self.named_parameters())