sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,426 @@
|
|
1
|
+
# Copyright 2023-2025 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
|
15
|
+
""" Inference-only Ernie4.5 model compatible with baidu/ERNIE-4.5-*-PT weights. """
|
16
|
+
|
17
|
+
from typing import Iterable, List, Optional, Tuple, Union
|
18
|
+
|
19
|
+
import torch
|
20
|
+
import torch.nn.functional as F
|
21
|
+
from torch import nn
|
22
|
+
from transformers.models.ernie4_5_moe.configuration_ernie4_5_moe import (
|
23
|
+
Ernie4_5_MoeConfig,
|
24
|
+
)
|
25
|
+
|
26
|
+
from sglang.srt.distributed import (
|
27
|
+
get_tensor_model_parallel_world_size,
|
28
|
+
tensor_model_parallel_all_reduce,
|
29
|
+
)
|
30
|
+
from sglang.srt.layers.communicator import enable_moe_dense_fully_dp
|
31
|
+
from sglang.srt.layers.layernorm import RMSNorm
|
32
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
33
|
+
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
34
|
+
from sglang.srt.layers.moe.topk import TopK
|
35
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
36
|
+
from sglang.srt.layers.vocab_parallel_embedding import (
|
37
|
+
ParallelLMHead,
|
38
|
+
VocabParallelEmbedding,
|
39
|
+
)
|
40
|
+
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
41
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
42
|
+
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
43
|
+
from sglang.srt.models.deepseek_v2 import DeepseekV2MLP as Ernie4MLP
|
44
|
+
from sglang.srt.models.llama import LlamaAttention as Ernie4Attention
|
45
|
+
from sglang.srt.utils import add_prefix, make_layers
|
46
|
+
|
47
|
+
|
48
|
+
class MoEGate(nn.Module):
    """Router gate for the Ernie4 mixture-of-experts block.

    Owns two parameters:
      * ``weight`` — the (num_experts, hidden_size) routing matrix.
      * ``e_score_correction_bias`` — a (1, num_experts) bias that is NOT
        applied here; it is handed to ``TopK`` by the enclosing MoE layer
        and used there during expert selection.
    """

    def __init__(
        self,
        config,
        prefix: str = "",
    ):
        super().__init__()
        # Routing matrix: one row of scores per expert.
        self.weight = nn.Parameter(
            torch.empty((config.moe_num_experts, config.hidden_size))
        )
        # Correction bias consumed by the top-k selector, not by forward().
        self.e_score_correction_bias = nn.Parameter(
            torch.empty((1, config.moe_num_experts))
        )

    def forward(self, hidden_states):
        """Return raw per-expert router logits for each token (no bias)."""
        return F.linear(hidden_states, self.weight, None)
|
65
|
+
|
66
|
+
|
67
|
+
class Ernie4Moe(nn.Module):
    """Mixture-of-experts FFN for one Ernie4 decoder layer.

    Combines a router gate (``MoEGate``), a top-k selector (``TopK``), the
    routed expert bank (implementation chosen by ``get_moe_impl_class``),
    and — when ``moe_num_shared_experts > 0`` — a dense shared-expert MLP
    whose output is added to the routed result.
    """

    def __init__(
        self,
        config: Ernie4_5_MoeConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_id = layer_id
        self.tp_size = get_tensor_model_parallel_world_size()
        # Number of always-active shared experts; 0 disables the shared path.
        self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", 0)

        # Only SiLU checkpoints are supported; fail fast on anything else.
        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))

        # Top-k selection uses the gate's correction bias during scoring.
        self.topk = TopK(
            top_k=config.moe_k,
            renormalize=True,
            use_grouped_topk=False,
            correction_bias=self.gate.e_score_correction_bias,
        )

        self.experts = get_moe_impl_class()(
            num_experts=config.moe_num_experts,
            top_k=config.moe_k,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            layer_id=self.layer_id,
            quant_config=quant_config,
            prefix=add_prefix("experts", prefix),
        )

        if self.moe_num_shared_experts > 0:
            # Shared experts are fused into one MLP of combined width.
            intermediate_size = (
                config.moe_intermediate_size * config.moe_num_shared_experts
            )
            # disable tp for shared experts when enable deepep moe
            # reduce_results=False: the single all-reduce below covers both
            # the routed and the shared contributions.
            self.shared_experts = Ernie4MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=add_prefix("shared_experts", prefix),
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Dispatch to the normal (non-overlapped) MoE forward path."""
        return self.forward_normal(hidden_states)

    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to experts, add shared-expert output, all-reduce.

        Returns a tensor of the same shape as ``hidden_states``.
        """
        shared_output = (
            self.shared_experts(hidden_states)
            if self.moe_num_shared_experts > 0
            else None
        )
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        topk_output = self.topk(hidden_states, router_logits)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, topk_output=topk_output
        )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            # One reduction for routed + shared (shared MLP skipped its own).
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states
|
139
|
+
|
140
|
+
|
141
|
+
class Ernie4DecoderLayer(nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.

    The MLP is either an ``Ernie4Moe`` block (for layers inside the
    configured MoE range, at the configured interval) or a dense
    ``Ernie4MLP``. MTP layers (``is_mtp=True``) always use the dense MLP.
    """

    def __init__(
        self,
        config,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        is_mtp: bool = False,
    ):
        super().__init__()
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        rope_is_neox_style = getattr(config, "rope_is_neox_style", False)
        # Self attention.
        self.self_attn = Ernie4Attention(
            config=config,
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            rope_is_neox_style=rope_is_neox_style,
            max_position_embeddings=config.max_position_embeddings,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
            bias=config.use_bias,
        )
        # Defaults place the MoE range past the last layer, i.e. no MoE
        # layers unless the config says otherwise.
        moe_layer_start_index = getattr(
            config, "moe_layer_start_index", config.num_hidden_layers
        )
        moe_layer_end_index = getattr(
            config, "moe_layer_end_index", config.num_hidden_layers - 1
        )
        # MLP
        if (not is_mtp) and (
            moe_layer_start_index <= layer_id <= moe_layer_end_index
            and (layer_id - moe_layer_start_index) % config.moe_layer_interval == 0
        ):
            self.mlp = Ernie4Moe(
                config=config,
                layer_id=layer_id,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
        else:
            if enable_moe_dense_fully_dp():
                # Dense layers run fully data-parallel: each rank keeps the
                # whole MLP (tp_size=1).
                mlp_tp_rank, mlp_tp_size = 0, 1
            else:
                # None lets Ernie4MLP fall back to the global TP settings.
                mlp_tp_rank, mlp_tp_size = None, None
            self.mlp = Ernie4MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                tp_rank=mlp_tp_rank,
                tp_size=mlp_tp_size,
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one decoder layer; returns (hidden_states, residual).

        ``residual`` is None on the first layer; thereafter RMSNorm's fused
        add-and-normalize form (two-argument call) carries it forward.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual
|
236
|
+
|
237
|
+
|
238
|
+
class Ernie4Model(nn.Module):
|
239
|
+
def __init__(
|
240
|
+
self,
|
241
|
+
config: Ernie4_5_MoeConfig,
|
242
|
+
quant_config: Optional[QuantizationConfig] = None,
|
243
|
+
prefix: str = "",
|
244
|
+
) -> None:
|
245
|
+
super().__init__()
|
246
|
+
self.config = config
|
247
|
+
self.embed_tokens = VocabParallelEmbedding(
|
248
|
+
config.vocab_size,
|
249
|
+
config.hidden_size,
|
250
|
+
quant_config=quant_config,
|
251
|
+
prefix=add_prefix("embed_tokens", prefix),
|
252
|
+
)
|
253
|
+
self.layers = make_layers(
|
254
|
+
config.num_hidden_layers,
|
255
|
+
lambda idx, prefix: Ernie4DecoderLayer(
|
256
|
+
config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
|
257
|
+
),
|
258
|
+
prefix="model.layers",
|
259
|
+
)
|
260
|
+
|
261
|
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
262
|
+
|
263
|
+
@torch.no_grad()
|
264
|
+
def forward(
|
265
|
+
self,
|
266
|
+
input_ids: torch.Tensor,
|
267
|
+
positions: torch.Tensor,
|
268
|
+
forward_batch: ForwardBatch,
|
269
|
+
input_embeds: torch.Tensor = None,
|
270
|
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
271
|
+
if input_embeds is None:
|
272
|
+
hidden_states = self.embed_tokens(input_ids)
|
273
|
+
else:
|
274
|
+
hidden_states = input_embeds
|
275
|
+
residual = None
|
276
|
+
for layer in self.layers:
|
277
|
+
hidden_states, residual = layer(
|
278
|
+
positions,
|
279
|
+
hidden_states,
|
280
|
+
forward_batch,
|
281
|
+
residual,
|
282
|
+
)
|
283
|
+
hidden_states, _ = self.norm(hidden_states, residual)
|
284
|
+
|
285
|
+
return hidden_states
|
286
|
+
|
287
|
+
|
288
|
+
class Ernie4_5_ForCausalLM(nn.Module):
    """Dense Ernie4.5 causal-LM wrapper: backbone + (possibly tied) LM head."""

    # Fused-module layout used by quantization / LoRA tooling.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
    stacked_params_mapping = [
        # (param_name, weight_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]

    def __init__(
        self,
        config: Ernie4_5_MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config: Ernie4_5_MoeConfig = config
        self.quant_config = quant_config
        self.model = Ernie4Model(config, quant_config, add_prefix("model", prefix))
        if config.tie_word_embeddings:
            # Tied checkpoints reuse the input embedding as the output head.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix="lm_head",
            )
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run the backbone and project hidden states to vocabulary logits."""
        hidden_states = self.model(input_ids, positions, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors, fusing q/k/v and gate/up into stacked params.

        Raises KeyError for any checkpoint name with no matching parameter.
        """
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # With tied embeddings the head weight is redundant; skip it.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            for param_name, weight_name, shard_id in self.stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Rewrite the checkpoint name to the fused parameter and load
                # this tensor into its shard slot.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Non-stacked parameter: load 1:1 by name.
                if name in params_dict.keys():
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                else:
                    raise KeyError(f"Parameter '{name}' not found in model.")

    def get_embed_and_head(self):
        """Return (input embedding weight, LM head weight) for weight sharing."""
        return self.model.embed_tokens.weight, self.lm_head.weight
|
360
|
+
|
361
|
+
|
362
|
+
class Ernie4_5_MoeForCausalLM(Ernie4_5_ForCausalLM):
    """MoE variant of Ernie4.5; overrides weight loading for expert params."""

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors, handling stacked params AND per-expert params.

        Resolution order per tensor: stacked (qkv / gate_up) mapping first,
        then the expert mapping, then a plain by-name load. MTP weights
        (``model.mtp_*``) are skipped — they belong to the MTP model.
        """
        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.moe_num_experts,
        )
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            if name.startswith("model.mtp_"):
                continue
            # Checkpoint stores the router correction bias under
            # "moe_statics"; our module keeps it on the gate.
            if "moe_statics.e_score_correction_bias" in name:
                name = name.replace("moe_statics", "gate")
            for param_name, weight_name, shard_id in self.stacked_params_mapping:
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if name in params_dict.keys():
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        # Expert loaders additionally need the (replaced)
                        # name plus shard/expert ids to place the tensor.
                        weight_loader(
                            param,
                            loaded_weight,
                            name,
                            shard_id=shard_id,
                            expert_id=expert_id,
                        )
                    else:
                        raise KeyError(
                            f"Parameter '{name}'(replaced) not found in model."
                        )
                    break
                else:
                    # Neither stacked nor expert: plain by-name load.
                    if name in params_dict.keys():
                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
                    else:
                        raise KeyError(f"Parameter '{name}' not found in model.")
|
424
|
+
|
425
|
+
|
426
|
+
# NOTE(review): presumably consumed by sglang's model registry to map
# checkpoint architecture names onto these classes — confirm against
# sglang/srt/models/registry.py.
EntryClass = [Ernie4_5_MoeForCausalLM, Ernie4_5_ForCausalLM]
|
@@ -0,0 +1,203 @@
|
|
1
|
+
# Copyright 2023-2025 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
|
15
|
+
""" Ernie4.5 MTP model compatible with baidu/ERNIE-4.5-*-PT weights. """
|
16
|
+
|
17
|
+
from typing import Iterable, Optional, Tuple
|
18
|
+
|
19
|
+
import torch
|
20
|
+
from torch import nn
|
21
|
+
from transformers.models.ernie4_5_moe.configuration_ernie4_5_moe import (
|
22
|
+
Ernie4_5_MoeConfig,
|
23
|
+
)
|
24
|
+
|
25
|
+
from sglang.srt.layers.layernorm import RMSNorm
|
26
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
27
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
28
|
+
from sglang.srt.layers.vocab_parallel_embedding import (
|
29
|
+
ParallelLMHead,
|
30
|
+
VocabParallelEmbedding,
|
31
|
+
)
|
32
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
33
|
+
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
34
|
+
from sglang.srt.models.ernie4 import Ernie4_5_ForCausalLM, Ernie4DecoderLayer
|
35
|
+
from sglang.srt.utils import add_prefix
|
36
|
+
|
37
|
+
|
38
|
+
class Ernie4ModelMTP(nn.Module):
    """Single multi-token-prediction (MTP) layer for Ernie4.5 speculative decoding.

    Fuses the token embedding with the base model's hidden state (taken
    from ``forward_batch.spec_info``) via a linear projection, then runs
    one dense decoder layer.
    """

    def __init__(
        self,
        config: Ernie4_5_MoeConfig,
        layer_id: int,
        prefix: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("embed_tokens", prefix),
        )
        # Separate norms for the embedding branch and the hidden-state branch.
        self.mtp_emb_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mtp_hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Projects the concatenated [embedding; hidden] back to hidden_size.
        self.mtp_linear_proj = nn.Linear(
            config.hidden_size * 2, config.hidden_size, bias=config.use_bias
        )
        # is_mtp=True forces the dense MLP path in the decoder layer.
        self.mtp_block = Ernie4DecoderLayer(
            config=config,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("mtp_block", prefix),
            is_mtp=True,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Return MTP hidden states for the draft step.

        Requires ``forward_batch.spec_info.hidden_states`` from the base model.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        # masking inputs at position 0, as not needed by MTP
        hidden_states[positions == 0] = 0

        hidden_states = self.mtp_linear_proj(
            torch.cat(
                (
                    self.mtp_emb_norm(hidden_states),
                    self.mtp_hidden_norm(forward_batch.spec_info.hidden_states),
                ),
                dim=-1,
            )
        )
        residual = None
        hidden_states, residual = self.mtp_block(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            residual=residual,
        )
        # Fold the residual stream back in; no final norm in the MTP head.
        hidden_states = residual + hidden_states
        return hidden_states
|
99
|
+
|
100
|
+
|
101
|
+
class Ernie4_5_MoeForCausalLMMTP(nn.Module):
    """Draft model wrapping one Ernie4.5 MTP layer for speculative decoding."""

    def __init__(
        self,
        config: Ernie4_5_MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        mtp_layer_id: int = 0,
    ) -> None:
        nn.Module.__init__(self)
        self.config = config
        # Which of the checkpoint's MTP layers this draft model represents.
        self.mtp_layer_id = mtp_layer_id

        self.model = Ernie4ModelMTP(
            config=config,
            layer_id=self.mtp_layer_id,
            quant_config=quant_config,
            prefix=add_prefix("model", prefix),
        )

        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix="lm_head",
            )
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run the MTP layer and project to vocabulary logits."""
        hidden_states = self.model(input_ids, positions, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load only this MTP layer's weights from the full checkpoint.

        Names must match one of the ``mtp_*.{mtp_layer_id}`` patterns; the
        layer index is stripped before lookup because each MTP layer is
        loaded as a standalone model. Raises KeyError if no MTP weights
        for this layer id are present.
        """
        mtp_layer_found = False
        mtp_weight_patterns = [
            f"mtp_block.{self.mtp_layer_id}",
            f"mtp_emb_norm.{self.mtp_layer_id}",
            f"mtp_hidden_norm.{self.mtp_layer_id}",
            f"mtp_linear_proj.{self.mtp_layer_id}",
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # Only name matched patterns should be loaded
            for layer_pattern in mtp_weight_patterns:
                if layer_pattern in name:
                    mtp_layer_found = True
                    break
            else:
                continue
            # But strip mtp_layer_id before loading, because each MTP layer is a MTP model.
            name = name.replace(f".{self.mtp_layer_id}.", ".")
            for (
                param_name,
                weight_name,
                shard_id,
            ) in Ernie4_5_ForCausalLM.stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if name in params_dict.keys():
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                else:
                    raise KeyError(f"Parameter '{name}' not found in MTP model.")
        if not mtp_layer_found:
            raise KeyError(
                f"MTP layers 'mtp_*.{self.mtp_layer_id}.*' not found in weights."
            )

    def get_embed_and_head(self):
        """Return (embedding weight, head weight) for sharing with the target model."""
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        """Adopt the target model's embedding/head tensors (frees the old ones)."""
        del self.model.embed_tokens.weight
        self.model.embed_tokens.weight = embed
        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            del self.lm_head.weight
            self.lm_head.weight = head
        # Reclaim the replaced tensors' GPU memory immediately.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
|
201
|
+
|
202
|
+
|
203
|
+
# NOTE(review): presumably consumed by sglang's model registry so the MTP
# draft model can be selected by architecture name — confirm against
# sglang/srt/models/registry.py.
EntryClass = [Ernie4_5_MoeForCausalLMMTP]
|
sglang/srt/models/gemma2.py
CHANGED
@@ -432,40 +432,6 @@ class Gemma2ForCausalLM(nn.Module):
|
|
432
432
|
|
433
433
|
return result
|
434
434
|
|
435
|
-
def get_hidden_dim(self, module_name):
|
436
|
-
# return input_dim, output_dim
|
437
|
-
if module_name in ["q_proj", "qkv_proj"]:
|
438
|
-
return (
|
439
|
-
self.config.hidden_size,
|
440
|
-
self.config.head_dim * self.config.num_attention_heads,
|
441
|
-
)
|
442
|
-
elif module_name in ["o_proj"]:
|
443
|
-
return (
|
444
|
-
self.config.head_dim * self.config.num_attention_heads,
|
445
|
-
self.config.hidden_size,
|
446
|
-
)
|
447
|
-
elif module_name in ["kv_proj"]:
|
448
|
-
return (
|
449
|
-
self.config.hidden_size,
|
450
|
-
self.config.head_dim * self.config.num_key_value_heads,
|
451
|
-
)
|
452
|
-
elif module_name == "gate_up_proj":
|
453
|
-
return self.config.hidden_size, self.config.intermediate_size
|
454
|
-
elif module_name == "down_proj":
|
455
|
-
return self.config.intermediate_size, self.config.hidden_size
|
456
|
-
else:
|
457
|
-
raise NotImplementedError()
|
458
|
-
|
459
|
-
def get_module_name(self, name):
|
460
|
-
params_mapping = {
|
461
|
-
"q_proj": "qkv_proj",
|
462
|
-
"k_proj": "qkv_proj",
|
463
|
-
"v_proj": "qkv_proj",
|
464
|
-
"gate_proj": "gate_up_proj",
|
465
|
-
"up_proj": "gate_up_proj",
|
466
|
-
}
|
467
|
-
return params_mapping.get(name, name)
|
468
|
-
|
469
435
|
def get_attention_sliding_window_size(self):
|
470
436
|
return get_attention_sliding_window_size(self.config)
|
471
437
|
|