sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
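Note on the deepseek_v2.py diff below: the most pervasive change is a new `prefix` argument threaded through every submodule constructor, so each layer is registered under its fully qualified weight path (e.g. `model.layers.0.self_attn.q_a_proj`) instead of the hard-coded `f"self_attn.kv_a_proj_with_mqa"` string it replaces. The `add_prefix` helper is imported from `sglang.srt.utils`, whose body is not shown in this diff; judging only from its call sites, it presumably joins a child name onto a dotted parent prefix, along the lines of this sketch (an inference, not the packaged implementation):

# Hypothetical reconstruction of sglang.srt.utils.add_prefix, inferred from
# call sites such as add_prefix("gate_up_proj", prefix). The shipped body
# may differ in detail.
def add_prefix(name: str, prefix: str) -> str:
    # An empty parent prefix yields the bare name; otherwise join with a dot,
    # e.g. add_prefix("experts", "model.layers.3.mlp")
    #      -> "model.layers.3.mlp.experts"
    return name if not prefix else f"{prefix}.{name}"

Fully qualified prefixes are what lets per-layer quantization configs (note the new gptq.py and blockwise_int8.py modules in the list above) match individual weights by their dotted path.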
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
 
+import os
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -31,6 +32,9 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+    decode_attention_fwd_grouped_rope,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -47,6 +51,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
+from sglang.srt.layers.quantization.int8_utils import (
+    block_dequant as int8_block_dequant,
+)
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
@@ -56,7 +63,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_cuda_available, is_hip
+from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
 
 is_hip_ = is_hip()
 
@@ -72,10 +79,15 @@ class DeepseekV2MLP(nn.Module):
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size,
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -83,6 +95,7 @@ class DeepseekV2MLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             reduce_results=reduce_results,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -99,7 +112,11 @@ class DeepseekV2MLP(nn.Module):
 
 
 class MoEGate(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        config,
+        prefix: str = "",
+    ):
         super().__init__()
         self.weight = nn.Parameter(
             torch.empty((config.n_routed_experts, config.hidden_size))
@@ -122,6 +139,7 @@ class DeepseekV2MoE(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -140,7 +158,7 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )
 
-        self.gate = MoEGate(config=config)
+        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
 
         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
         self.experts = MoEImpl(
@@ -154,6 +172,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
+            prefix=add_prefix("experts", prefix),
         )
 
         if config.n_shared_experts is not None:
@@ -164,6 +183,7 @@ class DeepseekV2MoE(nn.Module):
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
             )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -210,6 +230,7 @@ class DeepseekV2Attention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -234,6 +255,7 @@ class DeepseekV2Attention(nn.Module):
                 self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_a_proj", prefix),
             )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
@@ -241,6 +263,7 @@ class DeepseekV2Attention(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_b_proj", prefix),
             )
         else:
             self.q_proj = ColumnParallelLinear(
@@ -248,6 +271,7 @@ class DeepseekV2Attention(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_proj", prefix),
             )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -255,8 +279,7 @@ class DeepseekV2Attention(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
-
-            prefix=f"self_attn.kv_a_proj_with_mqa",
+            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -264,6 +287,7 @@ class DeepseekV2Attention(nn.Module):
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("kv_b_proj", prefix),
         )
         # O projection.
         self.o_proj = RowParallelLinear(
@@ -271,6 +295,7 @@ class DeepseekV2Attention(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope_wrapper(
@@ -296,6 +321,7 @@ class DeepseekV2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
 
     def forward(
@@ -361,6 +387,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
         use_dp=False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -387,6 +414,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.q_lora_rank,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_a_proj", prefix),
                 )
                 self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
                 self.q_b_proj = ReplicatedLinear(
@@ -394,6 +422,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.num_heads * self.qk_head_dim,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_b_proj", prefix),
                 )
             else:
                 self.q_proj = ReplicatedLinear(
@@ -401,12 +430,14 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.num_heads * self.qk_head_dim,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_proj", prefix),
                 )
             self.kv_b_proj = ReplicatedLinear(
                 self.kv_lora_rank,
                 self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("kv_b_proj", prefix),
             )
             # O projection.
             self.o_proj = ReplicatedLinear(
@@ -414,6 +445,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("o_proj", prefix),
             )
         else:
             # For tensor parallel attention
@@ -423,6 +455,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.q_lora_rank,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_a_proj", prefix),
                 )
                 self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
                 self.q_b_proj = ColumnParallelLinear(
@@ -430,6 +463,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.num_heads * self.qk_head_dim,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_b_proj", prefix),
                 )
             else:
                 self.q_proj = ColumnParallelLinear(
@@ -437,12 +471,14 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.num_heads * self.qk_head_dim,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_proj", prefix),
                 )
             self.kv_b_proj = ColumnParallelLinear(
                 self.kv_lora_rank,
                 self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("kv_b_proj", prefix),
             )
             # O projection.
             self.o_proj = RowParallelLinear(
@@ -450,6 +486,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("o_proj", prefix),
             )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -457,8 +494,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
-
-            prefix=f"self_attn.kv_a_proj_with_mqa",
+            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
 
@@ -489,6 +525,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            prefix=add_prefix("attn_mqa", prefix),
         )
 
         self.attn_mha = RadixAttention(
@@ -498,6 +535,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
             v_head_dim=self.v_head_dim,
+            prefix=add_prefix("attn_mha", prefix),
         )
 
         self.w_kc = None
@@ -510,23 +548,37 @@ class DeepseekV2AttentionMLA(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-
-
-
-
-
-
+
+        def no_absorb() -> bool:
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                # Flashinfer MLA: Do not absorb when enabling ragged prefill
+                return (
+                    not global_server_args_dict["flashinfer_mla_disable_ragged"]
+                    and forward_batch.forward_mode.is_extend()
+                    and forward_batch.extend_prefix_lens.sum() == 0
+                )
             else:
-
+                # Triton: Use normal computation for prefill and use weight absorption for extend/decode
+                return (
+                    forward_batch.forward_mode.is_extend()
+                    and not forward_batch.forward_mode.is_target_verify()
+                    and not forward_batch.forward_mode.is_draft_extend()
+                    and forward_batch.extend_prefix_lens.sum() == 0
+                )
+
+        if no_absorb():
+            return self.forward_normal(positions, hidden_states, forward_batch)
         else:
-
-
-
-
-
-
-
-
+            if is_hip_:
+                if (
+                    os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+                    and forward_batch.forward_mode.is_decode()
+                ):
+                    return self.forward_absorb_fused_mla_rope(
+                        positions, hidden_states, forward_batch
+                    )
+                else:
+                    return self.forward_absorb(positions, hidden_states, forward_batch)
+            else:
+                return self.forward_absorb(positions, hidden_states, forward_batch)
 
@@ -647,6 +699,149 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return output
 
+    def forward_absorb_fused_mla_rope(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        enable_rope_fusion = (
+            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
+        )
+        q_len = hidden_states.shape[0]
+        q_input = hidden_states.new_empty(
+            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
+        )
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+        if self.w_kc.dtype == torch.float8_e4m3fnuz:
+            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+            q_nope_out = torch.bmm(
+                q_nope.to(torch.bfloat16).transpose(0, 1),
+                self.w_kc.to(torch.bfloat16) * self.w_scale,
+            )
+        elif self.w_kc.dtype == torch.float8_e4m3fn:
+            q_nope_val, q_nope_scale = input_to_float8(
+                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            )
+            q_nope_out = bmm_fp8(
+                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+            )
+        else:
+            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
+
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        v_input = latent_cache[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
+        k_input = latent_cache.unsqueeze(1)
+        k_input[..., : self.kv_lora_rank] = v_input
+
+        if not enable_rope_fusion:
+            k_pe = k_input[..., self.kv_lora_rank :]
+            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+            q_input[..., self.kv_lora_rank :] = q_pe
+            k_input[..., self.kv_lora_rank :] = k_pe
+            k_pe_output = None
+        else:
+            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])
+
+            q_input[..., self.kv_lora_rank :] = q_pe
+
+        # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
+        # Use Fused ROPE with use_rope=OFF.
+        attn_output = torch.empty(
+            (q_len, self.num_local_heads, self.kv_lora_rank),
+            dtype=q.dtype,
+            device=q.device,
+        )
+        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
+            forward_batch.attn_backend.forward_metadata
+        )
+        cos_sin_cache = self.rotary_emb.cos_sin_cache
+        num_kv_split = forward_batch.attn_backend.num_kv_splits
+        sm_scale = self.attn_mqa.scaling
+        if attn_logits is None:
+            attn_logits = torch.empty(
+                (
+                    forward_batch.batch_size,
+                    self.num_local_heads,
+                    num_kv_split,
+                    self.kv_lora_rank + 1,
+                ),
+                dtype=torch.float32,
+                device=q.device,
+            )
+
+        # save current latent cache.
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
+        )
+        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+            self.attn_mqa.layer_id
+        )
+        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]
+
+        decode_attention_fwd_grouped_rope(
+            q_input,
+            key_cache_buf,
+            val_cache_buf,
+            attn_output,
+            kv_indptr,
+            kv_indices,
+            k_pe_output,
+            self.kv_lora_rank,
+            self.rotary_emb.rotary_dim,
+            cos_sin_cache,
+            positions,
+            attn_logits,
+            num_kv_split,
+            sm_scale,
+            logit_cap=self.attn_mqa.logit_cap,
+            use_rope=enable_rope_fusion,
+            is_neox_style=self.rotary_emb.is_neox_style,
+        )
+
+        if enable_rope_fusion:
+            k_input[..., self.kv_lora_rank :] = k_pe_output
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
+            )
+
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        if self.w_vc.dtype == torch.float8_e4m3fnuz:
+            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+            attn_bmm_output = torch.bmm(
+                attn_output.to(torch.bfloat16).transpose(0, 1),
+                self.w_vc.to(torch.bfloat16) * self.w_scale,
+            )
+        elif self.w_vc.dtype == torch.float8_e4m3fn:
+            attn_output_val, attn_output_scale = input_to_float8(
+                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            )
+            attn_bmm_output = bmm_fp8(
+                attn_output_val,
+                self.w_vc,
+                attn_output_scale,
+                self.w_scale,
+                torch.bfloat16,
+            )
+        else:
+            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
+        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
 
 def all_gather(
     input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
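Note: the forward_absorb_fused_mla_rope path added above is opt-in and only reachable on ROCm (HIP) builds; the dispatch in forward() additionally requires decode mode. Both environment variable names below are taken verbatim from the diff; everything else in this snippet is illustrative:

import os

# Opt in to the fused ROCm MLA decode kernel (checked in
# DeepseekV2AttentionMLA.forward; HIP builds and decode batches only).
os.environ["SGLANG_ROCM_FUSED_DECODE_MLA"] = "1"

# RoPE fusion inside the kernel already defaults to "1" per the diff;
# set it to "0" to apply rotary embeddings outside the kernel instead.
os.environ["SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION"] = "1"

# ...then launch the engine/server as usual; no API change is involved.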
@@ -654,16 +849,14 @@ def all_gather(
     if world_size == 1:
         return input_tensor
 
-    all_lens = forward_batch.
-    max_len = max(forward_batch.
+    all_lens = forward_batch.global_num_tokens_cpu
+    max_len = max(forward_batch.global_num_tokens_cpu)
 
     padded_tensor = torch.nn.functional.pad(
         input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
     )
 
-
-        forward_batch.gathered_buffer, padded_tensor, group=group
-    )
+    group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)
 
     gathered_tensors = torch.concat(
         [
@@ -686,6 +879,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -699,7 +893,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.enable_dp_attention:
             self.tp_rank = get_tensor_model_parallel_rank()
             self.tp_size = get_tensor_model_parallel_world_size()
-            self.tp_group = get_tp_group()
+            self.tp_group = get_tp_group()
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
@@ -718,6 +912,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 layer_id=layer_id,
                 use_dp=self.enable_dp_attention,
+                prefix=add_prefix("self_attn", prefix),
             )
         else:
             self.self_attn = DeepseekV2Attention(
@@ -736,19 +931,25 @@ class DeepseekV2DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                prefix=add_prefix("self_attn", prefix),
             )
         if is_nextn or (
            config.n_routed_experts is not None
            and layer_id >= config.first_k_dense_replace
            and layer_id % config.moe_layer_freq == 0
         ):
-            self.mlp = DeepseekV2MoE(
+            self.mlp = DeepseekV2MoE(
+                config=config,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
         else:
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
             )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -800,6 +1001,7 @@ class DeepseekV2Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_id = config.pad_token_id
@@ -816,6 +1018,7 @@ class DeepseekV2Model(nn.Module):
                     config,
                     layer_id,
                     quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{layer_id}", prefix),
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
@@ -846,21 +1049,28 @@ class DeepseekV2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekV2Model(
+        self.model = DeepseekV2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         if global_server_args_dict["enable_dp_attention"]:
             self.lm_head = ReplicatedLinear(
                 config.hidden_size,
                 config.vocab_size,
                 bias=False,
+                prefix=add_prefix("lm_head", prefix),
             )
             self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size,
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
             )
             self.logits_processor = LogitsProcessor(config)
 
@@ -992,6 +1202,18 @@ class DeepseekV2ForCausalLM(nn.Module):
                         weight, weight_scale, weight_block_size
                     )
                     self_attn.w_scale = scale
+            if (
+                hasattr(self.quant_config, "weight_block_size")
+                and w.dtype == torch.int8
+            ):
+                weight_block_size = self.quant_config.weight_block_size
+                if weight_block_size is not None:
+                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                    weight = w
+                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                    w = int8_block_dequant(
+                        weight, weight_scale, weight_block_size
+                    ).to(torch.bfloat16)
             w_kc, w_vc = w.unflatten(
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
@@ -1005,6 +1227,17 @@ class DeepseekV2ForCausalLM(nn.Module):
             if is_hip_:
                 self_attn.w_scale *= 2.0
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
 
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
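Note on the int8 branch in the weight-loading hunk above: int8_block_dequant comes from the new sglang/srt/layers/quantization/int8_utils.py (+73 lines in this release), whose body is not part of this diff. Conceptually, blockwise int8 dequantization rescales each weight tile by its per-block factor, matching the block_shape=[128, 128] tuning configs added alongside it. A minimal sketch, with argument names and shapes assumed rather than copied from the package:

import torch

def block_dequant_sketch(
    w_q: torch.Tensor,   # (N, K) int8 quantized weight
    w_s: torch.Tensor,   # (ceil(N/bn), ceil(K/bk)) per-block float scales
    block_size: list,    # e.g. [128, 128], as in the bundled configs
) -> torch.Tensor:
    bn, bk = block_size
    n, k = w_q.shape
    # Broadcast each per-block scale across its bn x bk tile, then crop to
    # (N, K) so ragged edge blocks are handled correctly.
    scale = w_s.repeat_interleave(bn, dim=0)[:n]
    scale = scale.repeat_interleave(bk, dim=1)[:, :k]
    return w_q.to(torch.float32) * scale

The diff then casts the dequantized result to bfloat16 before splitting it into the absorbed w_kc / w_vc matrices used by the MLA decode path.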