sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +220 -378
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +208 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
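The listing above can be reproduced locally. The sketch below is a minimal example, assuming both wheels have already been downloaded next to the script (for example with `pip download sglang==0.4.3.post2 --no-deps`); the wheel file names and the chosen target file are illustrative, not part of this report.

# Minimal sketch: print a unified diff for one file inside two locally downloaded wheels.
# The paths below are assumptions; a wheel is just a zip archive.
import difflib
import zipfile

OLD_WHEEL = "sglang-0.4.3.post2-py3-none-any.whl"  # assumed local path
NEW_WHEEL = "sglang-0.4.3.post3-py3-none-any.whl"  # assumed local path
TARGET = "sglang/srt/models/mllama.py"             # any file from the listing above

def read_member(wheel_path: str, member: str) -> list[str]:
    """Return the text lines of one file stored inside a wheel."""
    with zipfile.ZipFile(wheel_path) as zf:
        with zf.open(member) as f:
            return f.read().decode("utf-8").splitlines(keepends=True)

old_lines = read_member(OLD_WHEEL, TARGET)
new_lines = read_member(NEW_WHEEL, TARGET)

# Unified diff, in the same hunk format as the per-file sections below.
for line in difflib.unified_diff(
    old_lines, new_lines, fromfile=f"a/{TARGET}", tofile=f"b/{TARGET}"
):
    print(line, end="")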
sglang/srt/models/mllama.py
CHANGED
@@ -36,6 +36,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
+from sglang.srt.utils import add_prefix


 class ColumnParallelConv2dPatch(torch.nn.Module):
@@ -147,7 +148,12 @@ class MllamaPrecomputedPositionEmbedding(nn.Module):


 class MllamaVisionMLP(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
@@ -156,12 +162,14 @@ class MllamaVisionMLP(nn.Module):
             config.intermediate_size,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("fc1", prefix),
         )
         self.fc2 = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("fc2", prefix),
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -174,7 +182,10 @@ class MllamaVisionMLP(nn.Module):

 class MllamaVisionEncoderLayer(nn.Module):
     def __init__(
-        self,
+        self,
+        config: config_mllama.MllamaVisionConfig,
+        is_gated: bool = False,
+        prefix: str = "",
     ):
         super().__init__()

@@ -193,8 +204,9 @@ class MllamaVisionEncoderLayer(nn.Module):
             use_context_forward=False,
             use_full_precision_softmax=False,
             flatten_batch=False,
+            prefix=add_prefix("self_attn", prefix),
         )
-        self.mlp = MllamaVisionMLP(config)
+        self.mlp = MllamaVisionMLP(config, prefix=add_prefix("mlp", prefix))

         self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(
@@ -235,11 +247,17 @@ class MllamaVisionEncoder(nn.Module):
         num_layers=32,
         is_gated=False,
         output_hidden_states=None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
         self.layers = nn.ModuleList(
-            [
+            [
+                MllamaVisionEncoderLayer(
+                    config, is_gated, prefix=add_prefix(f"layers.{i}", prefix)
+                )
+                for i in range(num_layers)
+            ]
         )
         self.output_hidden_states = output_hidden_states or []

@@ -265,7 +283,7 @@ class MllamaVisionEncoder(nn.Module):


 class MllamaVisionModel(nn.Module):
-    def __init__(self, config: config_mllama.MllamaVisionConfig):
+    def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
         super().__init__()
         self.image_size = config.image_size
         self.patch_size = config.patch_size
@@ -305,9 +323,13 @@ class MllamaVisionModel(nn.Module):
             config.num_hidden_layers,
             is_gated=False,
             output_hidden_states=config.intermediate_layers_indices,
+            prefix=add_prefix("transformer", prefix),
         )
         self.global_transformer = MllamaVisionEncoder(
-            config,
+            config,
+            config.num_global_layers,
+            is_gated=True,
+            prefix=add_prefix("global_transformer", prefix),
         )

     def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
@@ -464,6 +486,7 @@ class MllamaTextCrossAttention(nn.Module):
         config: Optional[config_mllama.MllamaTextConfig] = None,
         layer_id: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -489,6 +512,7 @@ class MllamaTextCrossAttention(nn.Module):
             self.num_key_value_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.num_heads * self.head_dim,
@@ -496,6 +520,7 @@ class MllamaTextCrossAttention(nn.Module):
             bias=False,
             input_is_parallel=True,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
         # use huggingface's instead
@@ -510,6 +535,7 @@ class MllamaTextCrossAttention(nn.Module):
             self.num_local_key_value_heads,
             layer_id=layer_id,
             is_cross_attention=True,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -551,6 +577,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
         config: config_mllama.MllamaTextConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -558,6 +585,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
             config=config,
             layer_id=layer_id,
             quant_config=quant_config,
+            prefix=add_prefix("cross_attn", prefix),
         )

         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -568,6 +596,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -610,12 +639,15 @@ class MllamaTextModel(nn.Module):
         self,
         config: config_mllama.MllamaTextConfig,
         quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
     ):
         super().__init__()
         self.padding_id = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size + 8,
+            config.vocab_size + 8,
+            config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.cross_attention_layers = config.cross_attention_layers

@@ -624,14 +656,20 @@ class MllamaTextModel(nn.Module):
             if layer_id in self.cross_attention_layers:
                 layers.append(
                     MllamaCrossAttentionDecoderLayer(
-                        config,
+                        config,
+                        layer_id,
+                        quant_config=quant_config,
+                        prefix=add_prefix(f"layers.{layer_id}", prefix),
                     )
                 )
             else:
                 # TODO: force LlamaDecoderLayer to config.attention_bias=False
                 layers.append(
                     LlamaDecoderLayer(
-                        config,
+                        config,
+                        quant_config=quant_config,
+                        layer_id=layer_id,
+                        prefix=add_prefix(f"layers.{layer_id}", prefix),
                     )
                 )

@@ -687,16 +725,20 @@ class MllamaForCausalLM(nn.Module):
         self,
         config: config_mllama.MllamaTextConfig,
         quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
     ):
         super().__init__()
         self.vocab_size = config.vocab_size
-        self.model = MllamaTextModel(
+        self.model = MllamaTextModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
             quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
         )

     def forward(
@@ -726,6 +768,7 @@ class MllamaForConditionalGeneration(nn.Module):
         self,
         config: config_mllama.MllamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.vocab_size = config.text_config.vocab_size
@@ -737,10 +780,13 @@ class MllamaForConditionalGeneration(nn.Module):
         )
         self.image_size = config.vision_config.image_size

-        self.vision_model = MllamaVisionModel(
+        self.vision_model = MllamaVisionModel(
+            config.vision_config, prefix=add_prefix("vision_model", prefix)
+        )
         self.language_model = MllamaForCausalLM(
             config.text_config,
             quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
         )
         self.multi_modal_projector = nn.Linear(
             config.vision_config.vision_output_dim,
sglang/srt/models/olmo.py
CHANGED
@@ -38,7 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import make_layers
+from sglang.srt.utils import add_prefix, make_layers


 class OlmoAttention(nn.Module):
@@ -53,6 +53,7 @@ class OlmoAttention(nn.Module):
         config: OlmoConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -75,6 +76,7 @@ class OlmoAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             bias=config.attention_bias,
+            prefix=add_prefix("qkv_proj", prefix),
         )

         # Rotary embeddings.
@@ -91,6 +93,7 @@ class OlmoAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

         # Attention output projection.
@@ -98,6 +101,7 @@ class OlmoAttention(nn.Module):
             self.hidden_size,
             self.hidden_size,
             bias=config.attention_bias,
+            prefix=add_prefix("o_proj", prefix),
         )

     def forward(
@@ -127,6 +131,7 @@ class OlmoMLP(nn.Module):
         self,
         config: OlmoConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -139,6 +144,7 @@ class OlmoMLP(nn.Module):
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )

         # Activation function.
@@ -150,6 +156,7 @@ class OlmoMLP(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )

     def forward(
@@ -174,13 +181,23 @@ class OlmoDecoderLayer(nn.Module):
         config: OlmoConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         # Attention block.
-        self.self_attn = OlmoAttention(
+        self.self_attn = OlmoAttention(
+            config,
+            layer_id,
+            quant_config,
+            prefix=add_prefix("self_attn", prefix),
+        )

         # MLP block.
-        self.mlp = OlmoMLP(
+        self.mlp = OlmoMLP(
+            config,
+            quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )

         # LayerNorm
         self.input_layernorm = nn.LayerNorm(
@@ -213,13 +230,18 @@ class OlmoDecoderLayer(nn.Module):
 class OlmoModel(nn.Module):

     def __init__(
-        self,
+        self,
+        config: OlmoConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config

         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
+            config.vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -227,7 +249,9 @@ class OlmoModel(nn.Module):
                 layer_id=idx,
                 config=config,
                 quant_config=quant_config,
+                prefix=prefix,
             ),
+            prefix=add_prefix("layers", prefix),
         )
         self.norm = nn.LayerNorm(
             config.hidden_size, elementwise_affine=False, bias=False
@@ -275,10 +299,11 @@ class OlmoForCausalLM(nn.Module):
         self,
         config: OlmoConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
-        self.model = OlmoModel(config, quant_config)
+        self.model = OlmoModel(config, quant_config, prefix=add_prefix("model", prefix))
         if config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
@@ -288,6 +313,7 @@ class OlmoForCausalLM(nn.Module):
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
             quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
         )
         self.logits_processor = LogitsProcessor(config)

@@ -325,6 +351,8 @@ class OlmoForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
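The `load_weights` hunk above skips the checkpoint's `lm_head.weight` when `tie_word_embeddings` is set, since in that case `self.lm_head` is just an alias of `self.model.embed_tokens`. A toy illustration of why that tensor is redundant under tying follows; it is plain PyTorch, not sglang code.

import torch
import torch.nn as nn

vocab, hidden = 100, 16
embed_tokens = nn.Embedding(vocab, hidden)

# With tied word embeddings the output projection reuses the embedding matrix,
# so a separate lm_head.weight in the checkpoint carries no new information.
lm_head = nn.Linear(hidden, vocab, bias=False)
lm_head.weight = embed_tokens.weight

x = torch.randn(1, hidden)
logits = lm_head(x)  # equivalent to x @ embed_tokens.weight.T
assert torch.allclose(logits, x @ embed_tokens.weight.T)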
sglang/srt/models/olmo2.py
CHANGED
@@ -45,7 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import make_layers
+from sglang.srt.utils import add_prefix, make_layers


 class Olmo2Attention(nn.Module):
@@ -60,28 +60,29 @@ class Olmo2Attention(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = config.num_attention_heads

         assert self.hidden_size % self.total_num_heads == 0
-        assert self.total_num_heads % tp_size == 0
+        assert self.total_num_heads % self.tp_size == 0

-        self.num_heads = self.total_num_heads // tp_size
+        self.num_heads = self.total_num_heads // self.tp_size
         self.total_num_kv_heads = self.config.num_key_value_heads

-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= self.tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % self.tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert self.tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)

         self.head_dim = self.hidden_size // self.total_num_heads
         self.max_position_embeddings = config.max_position_embeddings
@@ -93,6 +94,8 @@ class Olmo2Attention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             bias=config.attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.tp_rank = get_tensor_model_parallel_rank()

@@ -115,6 +118,7 @@ class Olmo2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

         # Attention output projection.
@@ -122,6 +126,8 @@ class Olmo2Attention(nn.Module):
             self.head_dim * self.total_num_heads,
             self.hidden_size,
             bias=config.attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )

     def _apply_qk_norm(
@@ -164,6 +170,7 @@ class Olmo2MLP(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -176,6 +183,7 @@ class Olmo2MLP(nn.Module):
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )

         # Activation function.
@@ -187,6 +195,7 @@ class Olmo2MLP(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )

     def forward(
@@ -211,13 +220,16 @@ class Olmo2DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         # Attention block.
-        self.self_attn = Olmo2Attention(
+        self.self_attn = Olmo2Attention(
+            config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix)
+        )

         # MLP block.
-        self.mlp = Olmo2MLP(config, quant_config)
+        self.mlp = Olmo2MLP(config, quant_config, prefix=add_prefix("mlp", prefix))

         # RMSNorm
         self.post_attention_layernorm = RMSNorm(
@@ -254,12 +266,15 @@ class Olmo2Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config

         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
+            config.vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -267,7 +282,9 @@ class Olmo2Model(nn.Module):
                 layer_id=idx,
                 config=config,
                 quant_config=quant_config,
+                prefix=prefix,
             ),
+            prefix=add_prefix("layers", prefix),
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -313,10 +330,13 @@ class Olmo2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
-        self.model = Olmo2Model(
+        self.model = Olmo2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         if config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
@@ -326,6 +346,7 @@ class Olmo2ForCausalLM(nn.Module):
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
             quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
         )
         self.logits_processor = LogitsProcessor(config)

@@ -343,7 +364,7 @@ class Olmo2ForCausalLM(nn.Module):
             input_embeds=input_embeds,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/olmoe.py
CHANGED
@@ -41,7 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import make_layers, print_warning_once
+from sglang.srt.utils import add_prefix, make_layers, print_warning_once


 class OlmoeMoE(nn.Module):
@@ -69,7 +69,11 @@ class OlmoeMoE(nn.Module):

         # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(
-            hidden_size,
+            hidden_size,
+            num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=add_prefix("gate", prefix),
         )

         self.experts = FusedMoE(
@@ -81,6 +85,7 @@ class OlmoeMoE(nn.Module):
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            prefix=add_prefix("experts", prefix),
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -107,6 +112,7 @@ class OlmoeAttention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 4096,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -138,6 +144,7 @@ class OlmoeAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.q_norm = RMSNorm(hidden_size, eps=1e-5)
         self.k_norm = RMSNorm(hidden_size, eps=1e-5)
@@ -146,6 +153,7 @@ class OlmoeAttention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )

         self.rotary_emb = get_rope(
@@ -162,6 +170,7 @@ class OlmoeAttention(nn.Module):
             self.scaling,
             layer_id=layer_id,
             num_kv_heads=self.num_kv_heads,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -186,6 +195,7 @@ class OlmoeDecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -202,6 +212,7 @@ class OlmoeDecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )

         self.mlp = OlmoeMoE(
@@ -210,6 +221,7 @@ class OlmoeDecoderLayer(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
@@ -246,6 +258,7 @@ class OlmoeModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -254,6 +267,7 @@ class OlmoeModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -261,7 +275,9 @@ class OlmoeModel(nn.Module):
                 config=config,
                 quant_config=quant_config,
                 layer_id=idx,
+                prefix=prefix,
             ),
+            prefix=add_prefix("layers", prefix),
         )
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)

@@ -294,13 +310,19 @@ class OlmoeForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = OlmoeModel(
+        self.model = OlmoeModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         self.lm_head = ParallelLMHead(
-            config.vocab_size,
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
         )
         self.logits_processor = LogitsProcessor(config)
