sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +208 -295
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +238 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +209 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -29
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
```diff
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
 
+import os
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -31,6 +32,9 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+    decode_attention_fwd_grouped_rope,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -47,6 +51,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
+from sglang.srt.layers.quantization.int8_utils import (
+    block_dequant as int8_block_dequant,
+)
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
@@ -56,7 +63,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_cuda_available, is_hip
+from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
 
 is_hip_ = is_hip()
 
@@ -72,10 +79,15 @@ class DeepseekV2MLP(nn.Module):
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size,
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -83,6 +95,7 @@ class DeepseekV2MLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             reduce_results=reduce_results,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -99,7 +112,11 @@ class DeepseekV2MLP(nn.Module):
 
 
 class MoEGate(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        config,
+        prefix: str = "",
+    ):
         super().__init__()
         self.weight = nn.Parameter(
             torch.empty((config.n_routed_experts, config.hidden_size))
@@ -122,6 +139,7 @@ class DeepseekV2MoE(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -140,7 +158,7 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )
 
-        self.gate = MoEGate(config=config)
+        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
 
         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
         self.experts = MoEImpl(
@@ -154,6 +172,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
+            prefix=add_prefix("experts", prefix),
         )
 
         if config.n_shared_experts is not None:
@@ -164,6 +183,7 @@ class DeepseekV2MoE(nn.Module):
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
             )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -210,6 +230,7 @@ class DeepseekV2Attention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -234,6 +255,7 @@ class DeepseekV2Attention(nn.Module):
                 self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_a_proj", prefix),
             )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
@@ -241,6 +263,7 @@ class DeepseekV2Attention(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_b_proj", prefix),
             )
         else:
             self.q_proj = ColumnParallelLinear(
@@ -248,6 +271,7 @@ class DeepseekV2Attention(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_proj", prefix),
             )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -255,8 +279,7 @@ class DeepseekV2Attention(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
-
-            prefix=f"self_attn.kv_a_proj_with_mqa",
+            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -264,6 +287,7 @@ class DeepseekV2Attention(nn.Module):
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("kv_b_proj", prefix),
         )
         # O projection.
         self.o_proj = RowParallelLinear(
@@ -271,6 +295,7 @@ class DeepseekV2Attention(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope_wrapper(
@@ -296,6 +321,7 @@ class DeepseekV2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
 
     def forward(
@@ -361,6 +387,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
         use_dp=False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -387,6 +414,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_a_proj", prefix),
             )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
             self.q_b_proj = ReplicatedLinear(
@@ -394,6 +422,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_b_proj", prefix),
             )
         else:
             self.q_proj = ReplicatedLinear(
@@ -401,12 +430,14 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_proj", prefix),
             )
             self.kv_b_proj = ReplicatedLinear(
                 self.kv_lora_rank,
                 self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("kv_b_proj", prefix),
             )
             # O projection.
             self.o_proj = ReplicatedLinear(
@@ -414,6 +445,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("o_proj", prefix),
             )
         else:
             # For tensor parallel attention
@@ -423,6 +455,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_a_proj", prefix),
             )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
@@ -430,6 +463,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_b_proj", prefix),
             )
         else:
             self.q_proj = ColumnParallelLinear(
@@ -437,12 +471,14 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_proj", prefix),
             )
             self.kv_b_proj = ColumnParallelLinear(
                 self.kv_lora_rank,
                 self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("kv_b_proj", prefix),
             )
             # O projection.
             self.o_proj = RowParallelLinear(
@@ -450,6 +486,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("o_proj", prefix),
             )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -457,8 +494,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
-
-            prefix=f"self_attn.kv_a_proj_with_mqa",
+            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
 
@@ -489,6 +525,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            prefix=add_prefix("attn_mqa", prefix),
         )
 
         self.attn_mha = RadixAttention(
@@ -498,6 +535,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
             v_head_dim=self.v_head_dim,
+            prefix=add_prefix("attn_mha", prefix),
        )
 
         self.w_kc = None
@@ -510,20 +548,37 @@ class DeepseekV2AttentionMLA(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-
-
-
+
+        def no_absorb() -> bool:
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                # Flashinfer MLA: Do not absorb when enabling ragged prefill
+                return (
+                    not global_server_args_dict["flashinfer_mla_disable_ragged"]
+                    and forward_batch.forward_mode.is_extend()
+                    and forward_batch.extend_prefix_lens.sum() == 0
+                )
             else:
-
+                # Triton: Use normal computation for prefill and use weight absorption for extend/decode
+                return (
+                    forward_batch.forward_mode.is_extend()
+                    and not forward_batch.forward_mode.is_target_verify()
+                    and not forward_batch.forward_mode.is_draft_extend()
+                    and forward_batch.extend_prefix_lens.sum() == 0
+                )
+
+        if no_absorb():
+            return self.forward_normal(positions, hidden_states, forward_batch)
         else:
-
-
-
-
-
-
-
-
+            if is_hip_:
+                if (
+                    os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+                    and forward_batch.forward_mode.is_decode()
+                ):
+                    return self.forward_absorb_fused_mla_rope(
+                        positions, hidden_states, forward_batch
+                    )
+                else:
+                    return self.forward_absorb(positions, hidden_states, forward_batch)
             else:
                 return self.forward_absorb(positions, hidden_states, forward_batch)
 
@@ -644,6 +699,149 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return output
 
+    def forward_absorb_fused_mla_rope(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        enable_rope_fusion = (
+            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
+        )
+        q_len = hidden_states.shape[0]
+        q_input = hidden_states.new_empty(
+            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
+        )
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+        if self.w_kc.dtype == torch.float8_e4m3fnuz:
+            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+            q_nope_out = torch.bmm(
+                q_nope.to(torch.bfloat16).transpose(0, 1),
+                self.w_kc.to(torch.bfloat16) * self.w_scale,
+            )
+        elif self.w_kc.dtype == torch.float8_e4m3fn:
+            q_nope_val, q_nope_scale = input_to_float8(
+                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            )
+            q_nope_out = bmm_fp8(
+                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+            )
+        else:
+            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
+
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        v_input = latent_cache[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
+        k_input = latent_cache.unsqueeze(1)
+        k_input[..., : self.kv_lora_rank] = v_input
+
+        if not enable_rope_fusion:
+            k_pe = k_input[..., self.kv_lora_rank :]
+            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+            q_input[..., self.kv_lora_rank :] = q_pe
+            k_input[..., self.kv_lora_rank :] = k_pe
+            k_pe_output = None
+        else:
+            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])
+
+        q_input[..., self.kv_lora_rank :] = q_pe
+
+        # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
+        # Use Fused ROPE with use_rope=OFF.
+        attn_output = torch.empty(
+            (q_len, self.num_local_heads, self.kv_lora_rank),
+            dtype=q.dtype,
+            device=q.device,
+        )
+        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
+            forward_batch.attn_backend.forward_metadata
+        )
+        cos_sin_cache = self.rotary_emb.cos_sin_cache
+        num_kv_split = forward_batch.attn_backend.num_kv_splits
+        sm_scale = self.attn_mqa.scaling
+        if attn_logits is None:
+            attn_logits = torch.empty(
+                (
+                    forward_batch.batch_size,
+                    self.num_local_heads,
+                    num_kv_split,
+                    self.kv_lora_rank + 1,
+                ),
+                dtype=torch.float32,
+                device=q.device,
+            )
+
+        # save current latent cache.
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
+        )
+        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+            self.attn_mqa.layer_id
+        )
+        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]
+
+        decode_attention_fwd_grouped_rope(
+            q_input,
+            key_cache_buf,
+            val_cache_buf,
+            attn_output,
+            kv_indptr,
+            kv_indices,
+            k_pe_output,
+            self.kv_lora_rank,
+            self.rotary_emb.rotary_dim,
+            cos_sin_cache,
+            positions,
+            attn_logits,
+            num_kv_split,
+            sm_scale,
+            logit_cap=self.attn_mqa.logit_cap,
+            use_rope=enable_rope_fusion,
+            is_neox_style=self.rotary_emb.is_neox_style,
+        )
+
+        if enable_rope_fusion:
+            k_input[..., self.kv_lora_rank :] = k_pe_output
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
+            )
+
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        if self.w_vc.dtype == torch.float8_e4m3fnuz:
+            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+            attn_bmm_output = torch.bmm(
+                attn_output.to(torch.bfloat16).transpose(0, 1),
+                self.w_vc.to(torch.bfloat16) * self.w_scale,
+            )
+        elif self.w_vc.dtype == torch.float8_e4m3fn:
+            attn_output_val, attn_output_scale = input_to_float8(
+                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            )
+            attn_bmm_output = bmm_fp8(
+                attn_output_val,
+                self.w_vc,
+                attn_output_scale,
+                self.w_scale,
+                torch.bfloat16,
+            )
+        else:
+            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
+        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
 
 def all_gather(
     input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
@@ -651,16 +849,14 @@ def all_gather(
     if world_size == 1:
         return input_tensor
 
-    all_lens = forward_batch.
-    max_len = max(forward_batch.
+    all_lens = forward_batch.global_num_tokens_cpu
+    max_len = max(forward_batch.global_num_tokens_cpu)
 
     padded_tensor = torch.nn.functional.pad(
         input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
     )
 
-
-        forward_batch.gathered_buffer, padded_tensor, group=group
-    )
+    group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)
 
     gathered_tensors = torch.concat(
         [
@@ -683,6 +879,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -696,7 +893,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.enable_dp_attention:
             self.tp_rank = get_tensor_model_parallel_rank()
             self.tp_size = get_tensor_model_parallel_world_size()
-            self.tp_group = get_tp_group()
+            self.tp_group = get_tp_group()
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
@@ -715,6 +912,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 layer_id=layer_id,
                 use_dp=self.enable_dp_attention,
+                prefix=add_prefix("self_attn", prefix),
             )
         else:
             self.self_attn = DeepseekV2Attention(
@@ -733,19 +931,25 @@ class DeepseekV2DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                prefix=add_prefix("self_attn", prefix),
             )
         if is_nextn or (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
             and layer_id % config.moe_layer_freq == 0
         ):
-            self.mlp = DeepseekV2MoE(
+            self.mlp = DeepseekV2MoE(
+                config=config,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
         else:
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
             )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -797,6 +1001,7 @@ class DeepseekV2Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_id = config.pad_token_id
@@ -813,6 +1018,7 @@ class DeepseekV2Model(nn.Module):
                     config,
                     layer_id,
                     quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{layer_id}", prefix),
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
@@ -843,21 +1049,28 @@ class DeepseekV2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekV2Model(
+        self.model = DeepseekV2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         if global_server_args_dict["enable_dp_attention"]:
             self.lm_head = ReplicatedLinear(
                 config.hidden_size,
                 config.vocab_size,
                 bias=False,
+                prefix=add_prefix("lm_head", prefix),
             )
             self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size,
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
             )
             self.logits_processor = LogitsProcessor(config)
 
@@ -989,6 +1202,18 @@ class DeepseekV2ForCausalLM(nn.Module):
                             weight, weight_scale, weight_block_size
                         )
                         self_attn.w_scale = scale
+                if (
+                    hasattr(self.quant_config, "weight_block_size")
+                    and w.dtype == torch.int8
+                ):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                        w = int8_block_dequant(
+                            weight, weight_scale, weight_block_size
+                        ).to(torch.bfloat16)
                 w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
@@ -1002,6 +1227,17 @@ class DeepseekV2ForCausalLM(nn.Module):
                 if is_hip_:
                     self_attn.w_scale *= 2.0
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
 
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
```