sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/phi3_small.py
CHANGED
@@ -24,7 +24,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import make_layers
+from sglang.srt.utils import add_prefix, make_layers


 @torch.jit.script
@@ -70,13 +70,14 @@ class Phi3SmallMLP(nn.Module):
             2 * [self.intermediate_size],
             bias=True,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )

     def forward(self, x):
@@ -140,7 +141,7 @@ class Phi3SmallSelfAttention(nn.Module):
             self.num_key_value_heads,
             bias=True,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("qkv_proj", prefix),
         )

         self.dense = RowParallelLinear(
@@ -148,7 +149,7 @@ class Phi3SmallSelfAttention(nn.Module):
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("o_proj", prefix),
         )

         if getattr(self.config, "rope_scaling", None) is not None:
@@ -201,6 +202,7 @@ class Phi3SmallSelfAttention(nn.Module):
             self.scale,
             num_kv_heads=self.num_kv_heads_per_partion,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -234,13 +236,21 @@ class Phi3SmallDecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = Phi3SmallSelfAttention(
-            config,
+            config,
+            layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+        )
+        self.mlp = Phi3SmallMLP(
+            config,
+            quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
-        self.mlp = Phi3SmallMLP(config, quant_config)

         self.input_layernorm = nn.LayerNorm(
             config.hidden_size, eps=config.layer_norm_epsilon
@@ -284,15 +294,20 @@ class Phi3SmallModel(nn.Module):

         self.config = config
         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
+            config.vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.mup_embedding_multiplier = config.mup_embedding_multiplier
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: Phi3SmallDecoderLayer(
-                config,
+                config,
+                int(prefix.split(".")[-1]),
+                quant_config,
+                prefix=prefix,
             ),
-            prefix=
+            prefix=add_prefix("layers", prefix),
         )

         self.final_layernorm = nn.LayerNorm(
@@ -335,6 +350,7 @@ class Phi3SmallForCausalLM(nn.Module):
         self,
         config: Phi3Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):

         super().__init__()
@@ -344,7 +360,7 @@ class Phi3SmallForCausalLM(nn.Module):
         self.model = Phi3SmallModel(
             config=config,
             quant_config=quant_config,
-            prefix="model",
+            prefix=add_prefix("model", prefix),
         )
         self.vocab_size = config.vocab_size
         self.mup_width_multiplier = config.mup_width_multiplier
@@ -354,6 +370,7 @@ class Phi3SmallForCausalLM(nn.Module):
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
             quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
         )
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
@@ -433,6 +450,8 @@ class Phi3SmallForCausalLM(nn.Module):
                 continue
            if name.endswith(".bias") and name not in params_dict:
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
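The change repeated across all of these model files is that every submodule constructor now receives a dotted weight-name prefix built with add_prefix from sglang.srt.utils. The diff never shows that helper's body, so the following is only a minimal sketch, assuming the behavior implied by the call sites (join a child name onto an optional parent prefix with a dot); the composed example name is illustrative, not taken from the package:

    # Assumed behavior of sglang.srt.utils.add_prefix, inferred from the call sites above.
    def add_prefix(name: str, prefix: str) -> str:
        # With an empty parent prefix, the child keeps its bare name.
        return name if not prefix else f"{prefix}.{name}"

    # How prefixes compose as the constructors nest, e.g.
    # Phi3SmallForCausalLM -> "model" -> "model.layers.0" -> "model.layers.0.self_attn.qkv_proj":
    assert add_prefix("qkv_proj", add_prefix("self_attn", "model.layers.0")) == (
        "model.layers.0.self_attn.qkv_proj"
    )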
sglang/srt/models/qwen.py
CHANGED
@@ -39,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix


 class QWenMLP(nn.Module):
@@ -48,6 +49,7 @@ class QWenMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str = "silu",
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -56,6 +58,7 @@ class QWenMLP(nn.Module):
             bias=False,
             gather_output=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
@@ -63,6 +66,7 @@ class QWenMLP(nn.Module):
             bias=False,
             input_is_parallel=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -88,6 +92,7 @@ class QWenAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -104,6 +109,7 @@ class QWenAttention(nn.Module):
             self.total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_attn", prefix),
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
@@ -111,6 +117,7 @@ class QWenAttention(nn.Module):
             bias=False,
             input_is_parallel=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -126,6 +133,7 @@ class QWenAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -148,6 +156,7 @@ class QWenBlock(nn.Module):
         config: PretrainedConfig,
         layer_id,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -162,6 +171,7 @@ class QWenBlock(nn.Module):
             rope_scaling=rope_scaling,
             layer_id=layer_id,
             quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
         )

         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -170,6 +180,7 @@ class QWenBlock(nn.Module):
             config.hidden_size,
             config.intermediate_size // 2,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )

     def forward(
@@ -201,6 +212,7 @@ class QWenModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -210,10 +222,16 @@ class QWenModel(nn.Module):
         self.wte = VocabParallelEmbedding(
             vocab_size,
             config.hidden_size,
+            prefix=add_prefix("wte", prefix),
         )
         self.h = nn.ModuleList(
             [
-                QWenBlock(
+                QWenBlock(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"h.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -242,12 +260,17 @@ class QWenLMHeadModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
-        self.transformer = QWenModel(
+        self.transformer = QWenModel(
+            config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
+        )
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.lm_head = ParallelLMHead(
+        self.lm_head = ParallelLMHead(
+            vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
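One context line in QWenLMHeadModel worth noting: the vocabulary size is rounded up to a multiple of 64 before the parallel LM head is built (vocab_size = ((config.vocab_size + 63) // 64) * 64). A quick worked example of that rounding; the sample sizes below are chosen purely for illustration:

    # Round a vocab size up to the next multiple of 64, matching the context line above.
    def pad_vocab(vocab_size: int, align: int = 64) -> int:
        return ((vocab_size + align - 1) // align) * align

    assert pad_vocab(151851) == 151872  # padded up by 21 entries
    assert pad_vocab(151936) == 151936  # already a multiple of 64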
sglang/srt/models/qwen2.py
CHANGED
@@ -15,7 +15,7 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-
+from readline import add_history
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
@@ -46,7 +46,7 @@ from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
 )
-from sglang.srt.utils import make_layers
+from sglang.srt.utils import add_prefix, make_layers

 Qwen2Config = None

@@ -58,6 +58,7 @@ class Qwen2MLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -65,12 +66,14 @@ class Qwen2MLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -97,6 +100,7 @@ class Qwen2Attention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 32768,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -128,12 +132,14 @@ class Qwen2Attention(nn.Module):
             self.total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )

         self.rotary_emb = get_rope(
@@ -149,6 +155,7 @@ class Qwen2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -171,6 +178,7 @@ class Qwen2DecoderLayer(nn.Module):
         config: Qwen2Config,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -186,12 +194,14 @@ class Qwen2DecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -228,6 +238,7 @@ class Qwen2Model(nn.Module):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -237,6 +248,7 @@ class Qwen2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -244,7 +256,9 @@ class Qwen2Model(nn.Module):
                 layer_id=idx,
                 config=config,
                 quant_config=quant_config,
+                prefix=prefix,
             ),
+            prefix=add_prefix("layers", prefix),
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -325,16 +339,22 @@ class Qwen2ForCausalLM(nn.Module):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Qwen2Model(
+        self.model = Qwen2Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         if config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size,
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
             )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -377,6 +397,8 @@ class Qwen2ForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue

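Both qwen2.py and phi3_small.py add the same guard to load_weights: when config.tie_word_embeddings is set, lm_head aliases embed_tokens, so a checkpoint tensor named lm_head.weight has no separate parameter to fill. A minimal sketch of that loop shape; default_weight_loader here is a simplified stand-in for sglang's loader, and the free-function signature is an assumption for illustration:

    import torch
    from torch import nn

    def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
        # Plain copy, standing in for sglang's default_weight_loader.
        param.data.copy_(loaded_weight)

    def load_weights(model: nn.Module, tie_word_embeddings: bool, weights) -> None:
        params_dict = dict(model.named_parameters())
        for name, loaded_weight in weights:
            if tie_word_embeddings and "lm_head.weight" in name:
                # lm_head shares storage with embed_tokens; nothing separate to load.
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)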
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -52,6 +52,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs
+from sglang.srt.utils import add_prefix

 logger = logging.getLogger(__name__)

@@ -65,16 +66,29 @@ class Qwen2_5_VLMLP(nn.Module):
         bias: bool = True,
         hidden_act="silu",
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.gate_proj = ColumnParallelLinear(
-            in_features,
+            in_features,
+            hidden_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_proj", prefix),
         )
         self.up_proj = ColumnParallelLinear(
-            in_features,
+            in_features,
+            hidden_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
-            hidden_features,
+            hidden_features,
+            in_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         self.act = ACT2FN[hidden_act]

@@ -98,6 +112,7 @@ class Qwen2_5_VisionBlock(nn.Module):
         norm_layer: Type[nn.Module] = None,
         attn_implementation: Optional[str] = "sdpa",
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -123,9 +138,14 @@ class Qwen2_5_VisionBlock(nn.Module):
             use_full_precision_softmax=use_full_precision_softmax,
             flatten_batch=True,
             quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
         )
         self.mlp = Qwen2_5_VLMLP(
-            dim,
+            dim,
+            intermediate_dim,
+            hidden_act=hidden_act,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )

     def forward(
@@ -178,6 +198,7 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         context_dim: int,
         spatial_merge_size: int = 2,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -189,10 +210,15 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
                     self.hidden_size,
                     bias=True,
                     quant_config=quant_config,
+                    prefix=add_prefix("mlp.0", prefix),
                 ),
                 nn.GELU(),
                 RowParallelLinear(
-                    self.hidden_size,
+                    self.hidden_size,
+                    dim,
+                    bias=True,
+                    quant_config=quant_config,
+                    prefix=add_prefix("mlp.2", prefix),
                 ),
             ]
         )
@@ -250,6 +276,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         vision_config: Qwen2_5_VLVisionConfig,
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()

@@ -286,8 +313,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
                     norm_layer=norm_layer,
                     attn_implementation="sdpa",
                     quant_config=quant_config,
+                    prefix=add_prefix(f"blocks.{i}", prefix),
                 )
-                for
+                for i in range(depth)
             ]
         )
         self.merger = Qwen2_5_VisionPatchMerger(
@@ -295,6 +323,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
             context_dim=hidden_size,
             spatial_merge_size=spatial_merge_size,
             quant_config=quant_config,
+            prefix=add_prefix("merger", prefix),
         )

     def get_window_index(self, grid_thw):
@@ -447,6 +476,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self,
         config: Qwen2VLConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()

@@ -457,15 +487,23 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             # NOTE: Qwen2-VL vision encoder does not support any
             # quantization method now.
             quant_config=None,
+            prefix=add_prefix("visual", prefix),
         )

-        self.model = Qwen2Model(
+        self.model = Qwen2Model(
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+        )

         if config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size,
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
             )

         self.logits_processor = LogitsProcessor(config)
sglang/srt/models/qwen2_eagle.py
CHANGED
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+from sglang.srt.utils import add_prefix
+
 # Adapted from
 # https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
 """Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
@@ -42,7 +44,7 @@ class Qwen2DecoderLayer(Qwen2DecoderLayer):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
-        super().__init__(config, layer_id, quant_config)
+        super().__init__(config, layer_id, quant_config, prefix=prefix)

         # Skip the input_layernorm
         # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
@@ -56,6 +58,7 @@ class Qwen2Model(nn.Module):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -63,11 +66,15 @@ class Qwen2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
                 Qwen2DecoderLayer(
-                    config,
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -107,17 +114,22 @@ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
-
+        prefix: str = "",
     ) -> None:
         nn.Module.__init__(self)
         self.config = config
         self.quant_config = quant_config
-        self.model = Qwen2Model(
+        self.model = Qwen2Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size,
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
             )
         self.logits_processor = LogitsProcessor(config)
