sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix


 class Qwen2MoeMLP(nn.Module):
@@ -56,10 +57,15 @@ class Qwen2MoeMLP(nn.Module):
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size,
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -67,6 +73,7 @@ class Qwen2MoeMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             reduce_results=reduce_results,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -87,6 +94,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -105,10 +113,15 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
+            prefix=add_prefix("experts", prefix),
         )

         self.gate = ReplicatedLinear(
-            config.hidden_size,
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=add_prefix("gate", prefix),
         )
         if config.shared_expert_intermediate_size > 0:
             self.shared_expert = Qwen2MoeMLP(
@@ -117,6 +130,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=False,
+                prefix=add_prefix("shared_expert", prefix),
             )
         else:
             self.shared_expert = None
@@ -157,6 +171,7 @@ class Qwen2MoeAttention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -188,6 +203,7 @@ class Qwen2MoeAttention(nn.Module):
             self.total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )

         self.o_proj = RowParallelLinear(
@@ -195,6 +211,7 @@ class Qwen2MoeAttention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )

         self.rotary_emb = get_rope(
@@ -210,6 +227,7 @@ class Qwen2MoeAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -232,6 +250,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -247,6 +266,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )

         # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
@@ -257,13 +277,18 @@ class Qwen2MoeDecoderLayer(nn.Module):
         if (layer_id not in mlp_only_layers) and (
             config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
         ):
-            self.mlp = Qwen2MoeSparseMoeBlock(
+            self.mlp = Qwen2MoeSparseMoeBlock(
+                config=config,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
         else:
             self.mlp = Qwen2MoeMLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
             )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -300,6 +325,7 @@ class Qwen2MoeModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -308,10 +334,16 @@ class Qwen2MoeModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
-                Qwen2MoeDecoderLayer(
+                Qwen2MoeDecoderLayer(
+                    config,
+                    layer_id,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{layer_id}", prefix),
+                )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
@@ -346,13 +378,19 @@ class Qwen2MoeForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Qwen2MoeModel(
+        self.model = Qwen2MoeModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         self.lm_head = ParallelLMHead(
-            config.vocab_size,
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
         )
         self.logits_processor = LogitsProcessor(config)

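The change running through this and the following model files is the new `prefix` argument: each constructor now passes a dotted module path (for example `model.layers.0.self_attn.qkv_proj`) down to its sub-layers via `add_prefix`, so quantization configs and weight loaders can match parameters by fully qualified name. The snippet below is an illustrative sketch of what `add_prefix` is assumed to do; it is not taken from the packaged sources.

# Illustrative sketch only -- assumed behavior of sglang.srt.utils.add_prefix.
def add_prefix(name: str, prefix: str) -> str:
    """Join a sub-module name onto an existing dotted prefix."""
    return name if not prefix else f"{prefix}.{name}"

# Prefixes compose as the module tree is built:
assert add_prefix("qkv_proj", add_prefix("self_attn", "model.layers.0")) == (
    "model.layers.0.self_attn.qkv_proj"
)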
sglang/srt/models/qwen2_rm.py
ADDED
@@ -0,0 +1,78 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import Qwen2Config
+
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
+from sglang.srt.utils import add_prefix
+
+
+class Qwen2ForRewardModel(nn.Module):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.num_labels = 1
+        self.model = Qwen2Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.score = nn.Sequential(
+            nn.Linear(config.hidden_size, config.hidden_size),
+            nn.ReLU(),
+            nn.Linear(config.hidden_size, self.num_labels),
+        )
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
+
+        self.eos_token_id = config.eos_token_id
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
+    ) -> EmbeddingPoolerOutput:
+        assert get_embedding, "Qwen2ForRewardModel is only used for embedding"
+
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        logits = self.score(hidden_states)
+        pooled_logits = self.pooler(logits, forward_batch).embeddings
+
+        return EmbeddingPoolerOutput(pooled_logits)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # Filter out lm_head weights of Qwen2ForCausalLM
+        filtered_weights = [
+            (name, w) for name, w in weights if not name.startswith("lm_head")
+        ]
+        return Qwen2ForCausalLM.load_weights(self, filtered_weights)
+
+
+EntryClass = [
+    Qwen2ForRewardModel,
+]
sglang/srt/models/qwen2_vl.py
CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
+from sglang.srt.utils import add_prefix

 logger = logging.getLogger(__name__)

@@ -91,14 +92,21 @@ class Qwen2VisionMLP(nn.Module):
         hidden_features: int = None,
         act_layer: Type[nn.Module] = QuickGELU,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.fc1 = ColumnParallelLinear(
-            in_features,
+            in_features,
+            hidden_features,
+            quant_config=quant_config,
+            prefix=add_prefix("fc1", prefix),
         )
         self.act = act_layer()
         self.fc2 = RowParallelLinear(
-            hidden_features,
+            hidden_features,
+            in_features,
+            quant_config=quant_config,
+            prefix=add_prefix("fc2", prefix),
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -119,6 +127,7 @@ class Qwen2VisionBlock(nn.Module):
         norm_layer: Type[nn.Module] = None,
         attn_implementation: Optional[str] = "sdpa",
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -145,9 +154,14 @@ class Qwen2VisionBlock(nn.Module):
             use_full_precision_softmax=use_full_precision_softmax,
             flatten_batch=True,
             quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
         )
         self.mlp = Qwen2VisionMLP(
-            dim,
+            dim,
+            mlp_hidden_dim,
+            act_layer=act_layer,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )

     def forward(
@@ -199,6 +213,7 @@ class Qwen2VisionPatchMerger(nn.Module):
         norm_layer: Type[nn.Module] = None,
         spatial_merge_size: int = 2,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -212,10 +227,15 @@ class Qwen2VisionPatchMerger(nn.Module):
                     self.hidden_size,
                     bias=True,
                     quant_config=quant_config,
+                    prefix=add_prefix("mlp.0", prefix),
                 ),
                 nn.GELU(),
                 RowParallelLinear(
-                    self.hidden_size,
+                    self.hidden_size,
+                    d_model,
+                    bias=True,
+                    quant_config=quant_config,
+                    prefix=add_prefix("mlp.2", prefix),
                 ),
             ]
         )
@@ -273,6 +293,7 @@ class Qwen2VisionTransformer(nn.Module):
         vision_config: Qwen2VLVisionConfig,
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()

@@ -307,8 +328,9 @@ class Qwen2VisionTransformer(nn.Module):
                     norm_layer=norm_layer,
                     attn_implementation="sdpa",
                     quant_config=quant_config,
+                    prefix=add_prefix(f"blocks.{i}", prefix),
                 )
-                for
+                for i in range(depth)
             ]
         )
         self.merger = Qwen2VisionPatchMerger(
@@ -316,6 +338,7 @@ class Qwen2VisionTransformer(nn.Module):
             context_dim=embed_dim,
             norm_layer=norm_layer,
             quant_config=quant_config,
+            prefix=add_prefix("merger", prefix),
         )

     @property
@@ -440,6 +463,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self,
         config: Qwen2VLConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()

@@ -450,15 +474,21 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             # NOTE: Qwen2-VL vision encoder does not support any
             # quantization method now.
             quant_config=None,
+            prefix=add_prefix("visual", prefix),
         )

-        self.model = Qwen2Model(
+        self.model = Qwen2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )

         if config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size,
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
             )

         self.logits_processor = LogitsProcessor(config)
@@ -559,7 +589,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                 ]
                 image_embeds_offset += num_image_tokens

-        input_ids = None
         hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
@@ -587,6 +616,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue

             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
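Both Qwen2-VL above and the torch-native Llama loader further below now skip `lm_head.weight` when `tie_word_embeddings` is set: with tied weights the output head reuses `embed_tokens`, so a checkpoint that still ships a separate `lm_head.weight` tensor must not be routed to a parameter that no longer exists. The sketch below restates that filter in isolation; names follow the diff, while the surrounding loader details are omitted and the generator wrapper is an assumption for illustration.

# Simplified sketch of the tied-embedding filter added in load_weights().
def iter_loadable_weights(weights, tie_word_embeddings: bool):
    for name, tensor in weights:
        if "rotary_emb.inv_freq" in name:
            continue  # rotary caches are recomputed, never loaded
        if tie_word_embeddings and "lm_head.weight" in name:
            continue  # lm_head shares storage with embed_tokens
        yield name, tensor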
sglang/srt/models/stablelm.py
CHANGED
@@ -42,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix


 class StablelmMLP(nn.Module):
@@ -49,6 +50,7 @@ class StablelmMLP(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -59,12 +61,14 @@ class StablelmMLP(nn.Module):
             [config.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         self.act_fn = SiluAndMul()

@@ -81,6 +85,7 @@ class StablelmAttention(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -122,11 +127,15 @@ class StablelmAttention(nn.Module):
             self.total_num_heads,
             self.total_num_key_value_heads,
             self.qkv_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -140,6 +149,7 @@ class StablelmAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_key_value_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -162,10 +172,15 @@ class StablelmDecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
-        self.self_attn = StablelmAttention(
-
+        self.self_attn = StablelmAttention(
+            config, layer_id=layer_id, prefix=add_prefix("self_attn", prefix)
+        )
+        self.mlp = StablelmMLP(
+            config, quant_config=quant_config, prefix=add_prefix("mlp", prefix)
+        )
         norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
@@ -200,15 +215,22 @@ class StableLMEpochModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
-                StablelmDecoderLayer(
+                StablelmDecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -242,12 +264,17 @@ class StableLmForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = StableLMEpochModel(
-
+        self.model = StableLMEpochModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
sglang/srt/models/torch_native_llama.py
CHANGED
@@ -64,6 +64,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix

 tp_size = get_tensor_model_parallel_world_size()
 tp_rank = get_tensor_model_parallel_rank()
@@ -294,14 +295,14 @@ class LlamaDecoderLayer(nn.Module):
             rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -486,6 +487,8 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue

             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
sglang/srt/models/xverse.py
CHANGED
@@ -40,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.model_runner import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix


 class XverseMLP(nn.Module):
@@ -57,14 +58,14 @@ class XverseMLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -128,14 +129,14 @@ class XverseAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("o_proj", prefix),
         )

         self.rotary_emb = get_rope(
@@ -152,6 +153,7 @@ class XverseAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -202,14 +204,14 @@ class XverseDecoderLayer(nn.Module):
             rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = XverseMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -246,6 +248,7 @@ class XverseModel(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -254,11 +257,15 @@ class XverseModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
                 XverseDecoderLayer(
-                    config,
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -295,12 +302,17 @@ class XverseForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = XverseModel(
-
+        self.model = XverseModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()