sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +220 -378
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +208 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/exaone.py
CHANGED
@@ -39,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 
 class ExaoneGatedMLP(nn.Module):
@@ -56,14 +57,14 @@ class ExaoneGatedMLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("c_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -130,14 +131,14 @@ class ExaoneAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.out_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("out_proj", prefix),
         )
 
         self.rotary_emb = get_rope(
@@ -201,14 +202,14 @@ class ExaoneDecoderLayer(nn.Module):
             rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = ExaoneGatedMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.activation_function,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("mlp", prefix),
         )
         rms_norm_eps = config.layer_norm_epsilon
         self.ln_1 = RMSNorm(config.hidden_size, eps=rms_norm_eps)
@@ -244,6 +245,7 @@ class ExaoneModel(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -256,7 +258,10 @@ class ExaoneModel(nn.Module):
         self.h = nn.ModuleList(
             [
                 ExaoneDecoderLayer(
-                    config,
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"h.{i}", prefix),
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -293,12 +298,17 @@ class ExaoneForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = ExaoneModel(
-
+        self.transformer = ExaoneModel(
+            config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
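The same change repeats across every model file in this diff: each submodule constructor now receives a dotted weight-path prefix built with `add_prefix`, newly imported from `sglang.srt.utils`. The helper itself is not shown in this diff; as a rough sketch, under the assumption that it simply joins a child name onto an optional parent prefix, it could look like this:

```python
def add_prefix(name: str, prefix: str) -> str:
    # Sketch only: the real helper lives in sglang/srt/utils.py and may differ.
    # Join a child module name onto an optional dotted parent prefix.
    return name if not prefix else f"{prefix}.{name}"


# With the ExaoneForCausalLM wiring above, a layer's projection would resolve to
# a full weight path such as:
#   add_prefix("gate_up_proj", "transformer.h.0.mlp") -> "transformer.h.0.mlp.gate_up_proj"
```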
sglang/srt/models/gemma.py
CHANGED
@@ -37,6 +37,7 @@ from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 
 class GemmaMLP(nn.Module):
@@ -45,6 +46,7 @@ class GemmaMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -52,12 +54,14 @@ class GemmaMLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         self.act_fn = GeluAndMul("none")
 
@@ -79,6 +83,7 @@ class GemmaAttention(nn.Module):
         max_position_embeddings: int = 8192,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -109,12 +114,14 @@ class GemmaAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
 
         self.rotary_emb = get_rope(
@@ -130,6 +137,7 @@ class GemmaAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
 
     def forward(
@@ -152,6 +160,7 @@ class GemmaDecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -164,11 +173,13 @@ class GemmaDecoderLayer(nn.Module):
             max_position_embeddings=config.max_position_embeddings,
             rope_theta=config.rope_theta,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = GemmaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -205,6 +216,7 @@ class GemmaModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -215,7 +227,12 @@ class GemmaModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                GemmaDecoderLayer(
+                GemmaDecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -277,11 +294,14 @@ class GemmaForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = GemmaModel(
+        self.model = GemmaModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
@@ -336,12 +356,6 @@ class GemmaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
-        unloaded_params = params_dict.keys() - loaded_params
-        if unloaded_params:
-            raise RuntimeError(
-                "Some weights are not initialized from checkpoints: "
-                f"{unloaded_params}"
-            )
 
 
 EntryClass = GemmaForCausalLM
sglang/srt/models/gemma2.py
CHANGED
@@ -39,7 +39,7 @@ from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
 )
-from sglang.srt.utils import make_layers
+from sglang.srt.utils import add_prefix, make_layers
 
 
 # Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -56,13 +56,22 @@ class Gemma2MLP(nn.Module):
         hidden_act: str,
         hidden_activation: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size,
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size,
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
             raise ValueError(
@@ -91,6 +100,7 @@ class Gemma2Attention(nn.Module):
         max_position_embeddings: int,
         rope_theta: float,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -123,12 +133,14 @@ class Gemma2Attention(nn.Module):
             self.total_num_kv_heads,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -151,6 +163,7 @@ class Gemma2Attention(nn.Module):
                 if use_sliding_window
                 else None
             ),
+            prefix=add_prefix("attn", prefix),
         )
 
     def forward(
@@ -173,6 +186,7 @@ class Gemma2DecoderLayer(nn.Module):
         layer_id: int,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -186,6 +200,7 @@ class Gemma2DecoderLayer(nn.Module):
             max_position_embeddings=config.max_position_embeddings,
             rope_theta=config.rope_theta,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )
         self.hidden_size = config.hidden_size
         self.mlp = Gemma2MLP(
@@ -194,6 +209,7 @@ class Gemma2DecoderLayer(nn.Module):
             hidden_act=config.hidden_act,
             hidden_activation=config.hidden_activation,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GemmaRMSNorm(
@@ -238,6 +254,7 @@ class Gemma2Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -253,7 +270,7 @@ class Gemma2Model(nn.Module):
                 config=config,
                 quant_config=quant_config,
             ),
-            prefix="",
+            prefix=add_prefix("layers", prefix),
         )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -339,11 +356,14 @@ class Gemma2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Gemma2Model(
+        self.model = Gemma2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
@@ -437,12 +457,5 @@ class Gemma2ForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
 
-        unloaded_params = params_dict.keys() - loaded_params
-        if unloaded_params:
-            raise RuntimeError(
-                "Some weights are not initialized from checkpoints: "
-                f"{unloaded_params}"
-            )
-
 
 EntryClass = Gemma2ForCausalLM
sglang/srt/models/gemma2_reward.py
CHANGED
@@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.gemma2 import Gemma2ForCausalLM, Gemma2Model
+from sglang.srt.utils import add_prefix
 
 
 class Gemma2ForSequenceClassification(nn.Module):
@@ -29,12 +30,15 @@ class Gemma2ForSequenceClassification(nn.Module):
         self,
         config: Gemma2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.num_labels = config.num_labels
-        self.model = Gemma2Model(
+        self.model = Gemma2Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
 
sglang/srt/models/gpt2.py
CHANGED
@@ -17,14 +17,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-2 model compatible with HuggingFace weights."""
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional, Tuple, Type
 
 import torch
 from torch import nn
 from transformers import GPT2Config
 
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
-from sglang.srt.layers.activation import
+from sglang.srt.layers.activation import NewGELU
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -36,6 +36,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 
 class GPT2Attention(nn.Module):
@@ -62,14 +63,14 @@ class GPT2Attention(nn.Module):
             total_num_heads,
             bias=True,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("c_attn", prefix),
         )
         self.c_proj = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("c_proj", prefix),
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -97,6 +98,7 @@ class GPT2MLP(nn.Module):
         self,
         intermediate_size: int,
         config: GPT2Config,
+        act_layer: Type[nn.Module] = NewGELU,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -107,18 +109,16 @@ class GPT2MLP(nn.Module):
             intermediate_size,
             bias=True,
             quant_config=quant_config,
-            prefix=
+            prefix=add_prefix("c_fc", prefix),
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=True,
             quant_config=quant_config,
-            prefix=
-        )
-        self.act = get_act_fn(
-            config.activation_function, quant_config, intermediate_size
+            prefix=add_prefix("c_proj", prefix),
         )
+        self.act = act_layer()
 
     def forward(
         self,
@@ -136,6 +136,7 @@ class GPT2Block(nn.Module):
         self,
         layer_id: int,
         config: GPT2Config,
+        act_layer: Type[nn.Module] = NewGELU,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -145,10 +146,16 @@ class GPT2Block(nn.Module):
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = GPT2Attention(
-            layer_id, config, quant_config, prefix=
+            layer_id, config, quant_config, prefix=add_prefix("attn", prefix)
         )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPT2MLP(
+        self.mlp = GPT2MLP(
+            inner_dim,
+            config,
+            act_layer=act_layer,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
 
     def forward(
         self,
@@ -190,7 +197,12 @@ class GPT2Model(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPT2Block(
+                GPT2Block(
+                    i,
+                    config,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"h.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -221,11 +233,14 @@ class GPT2LMHeadModel(nn.Module):
         self,
         config: GPT2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = GPT2Model(
+        self.transformer = GPT2Model(
+            config, quant_config, prefix=add_prefix("transformer", prefix)
+        )
         self.lm_head = self.transformer.wte
 
         self.logits_processor = LogitsProcessor(config)
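Besides the prefix threading, the gpt2.py diff replaces the activation previously resolved through `get_act_fn(config.activation_function, ...)` with an `act_layer` class that defaults to the newly imported `NewGELU`. The class body is not part of this diff; for orientation only, a `NewGELU`-style module usually computes GPT-2's tanh approximation of GELU. A sketch, not necessarily the sglang implementation:

```python
import math

import torch
from torch import nn


class NewGELU(nn.Module):
    # Sketch of a "gelu_new"-style activation (GPT-2's tanh approximation of GELU);
    # the actual class lives in sglang/srt/layers/activation.py and may differ.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (
            0.5
            * x
            * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3.0))))
        )
```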
sglang/srt/models/gpt_bigcode.py
CHANGED
@@ -35,6 +35,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 
 class GPTBigCodeAttention(nn.Module):
@@ -44,6 +45,7 @@ class GPTBigCodeAttention(nn.Module):
         layer_id: int,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -69,6 +71,7 @@ class GPTBigCodeAttention(nn.Module):
             total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_attn", prefix),
         )
 
         self.c_proj = RowParallelLinear(
@@ -76,6 +79,7 @@ class GPTBigCodeAttention(nn.Module):
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_proj", prefix),
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -83,6 +87,7 @@ class GPTBigCodeAttention(nn.Module):
             scaling=self.scale,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
 
     def forward(
@@ -111,6 +116,7 @@ class GPTBigMLP(nn.Module):
         intermediate_size: int,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -119,12 +125,14 @@ class GPTBigMLP(nn.Module):
             intermediate_size,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_fc", prefix),
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_proj", prefix),
         )
         self.act = get_act_fn(
             config.activation_function, quant_config, intermediate_size
@@ -144,15 +152,20 @@ class GPTBigCodeBlock(nn.Module):
         layer_id: int,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPTBigCodeAttention(
+        self.attn = GPTBigCodeAttention(
+            layer_id, config, quant_config, prefix=add_prefix("attn", prefix)
+        )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPTBigMLP(
+        self.mlp = GPTBigMLP(
+            inner_dim, config, quant_config, prefix=add_prefix("mlp", prefix)
+        )
 
     def forward(
         self,
@@ -181,6 +194,7 @@ class GPTBigCodeModel(nn.Module):
         self,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -190,12 +204,17 @@ class GPTBigCodeModel(nn.Module):
         lora_vocab = 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.wte = VocabParallelEmbedding(
-            self.vocab_size,
+            self.vocab_size,
+            self.embed_dim,
+            org_num_embeddings=config.vocab_size,
+            prefix=add_prefix("wte", prefix),
         )
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPTBigCodeBlock(
+                GPTBigCodeBlock(
+                    i, config, quant_config, prefix=add_prefix(f"h.{i}", prefix)
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -235,13 +254,16 @@ class GPTBigCodeForCausalLM(nn.Module):
         self,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
 
         self.config = config
 
         self.quant_config = quant_config
-        self.transformer = GPTBigCodeModel(
+        self.transformer = GPTBigCodeModel(
+            config, quant_config, prefix=add_prefix("transformer", prefix)
+        )
         self.lm_head = self.transformer.wte
         self.unpadded_vocab_size = config.vocab_size
         self.logits_processor = LogitsProcessor(config)