sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those versions exactly as published.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
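
The version string moves from 0.4.10.post2 to 0.5.0rc1 (`sglang/version.py`), and under PEP 440 an `rc` version is a pre-release, so a plain `pip install sglang` will not pick it up unless the version is pinned explicitly or `--pre` is passed. Below is a small sketch for checking which side of this diff is installed in an environment; it assumes the `packaging` library is available, which this wheel does not itself guarantee.

```python
# Check whether the installed sglang wheel is the 0.5.0rc1 pre-release or the
# older 0.4.10.post2 build. `packaging` is assumed to be importable here.
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("sglang"))
print(f"sglang {installed} (pre-release: {installed.is_prerelease})")

if installed < Version("0.5.0rc1"):
    print("Older wheel: none of the changes in this diff are present.")
```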
sglang/srt/models/ernie4.py (new file)
@@ -0,0 +1,426 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+""" Inference-only Ernie4.5 model compatible with baidu/ERNIE-4.5-*-PT weights. """
+
+from typing import Iterable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers.models.ernie4_5_moe.configuration_ernie4_5_moe import (
+    Ernie4_5_MoeConfig,
+)
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.layers.communicator import enable_moe_dense_fully_dp
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.deepseek_v2 import DeepseekV2MLP as Ernie4MLP
+from sglang.srt.models.llama import LlamaAttention as Ernie4Attention
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class MoEGate(nn.Module):
+    def __init__(
+        self,
+        config,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.weight = nn.Parameter(
+            torch.empty((config.moe_num_experts, config.hidden_size))
+        )
+        self.e_score_correction_bias = nn.Parameter(
+            torch.empty((1, config.moe_num_experts))
+        )
+
+    def forward(self, hidden_states):
+        logits = F.linear(hidden_states, self.weight, None)
+        return logits
+
+
+class Ernie4Moe(nn.Module):
+    def __init__(
+        self,
+        config: Ernie4_5_MoeConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.layer_id = layer_id
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", 0)
+
+        if config.hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {config.hidden_act}. "
+                "Only silu is supported for now."
+            )
+
+        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
+
+        self.topk = TopK(
+            top_k=config.moe_k,
+            renormalize=True,
+            use_grouped_topk=False,
+            correction_bias=self.gate.e_score_correction_bias,
+        )
+
+        self.experts = get_moe_impl_class()(
+            num_experts=config.moe_num_experts,
+            top_k=config.moe_k,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            layer_id=self.layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("experts", prefix),
+        )
+
+        if self.moe_num_shared_experts > 0:
+            intermediate_size = (
+                config.moe_intermediate_size * config.moe_num_shared_experts
+            )
+            # disable tp for shared experts when enable deepep moe
+            self.shared_experts = Ernie4MLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
+            )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.forward_normal(hidden_states)
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        shared_output = (
+            self.shared_experts(hidden_states)
+            if self.moe_num_shared_experts > 0
+            else None
+        )
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, topk_output=topk_output
+        )
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
+
+class Ernie4DecoderLayer(nn.Module):
+    """A single transformer layer.
+
+    Transformer layer takes input with size [s, b, h] and returns an
+    output of the same size.
+    """
+
+    def __init__(
+        self,
+        config,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        is_mtp: bool = False,
+    ):
+        super().__init__()
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", False)
+        # Self attention.
+        self.self_attn = Ernie4Attention(
+            config=config,
+            hidden_size=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
+            max_position_embeddings=config.max_position_embeddings,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+            bias=config.use_bias,
+        )
+        moe_layer_start_index = getattr(
+            config, "moe_layer_start_index", config.num_hidden_layers
+        )
+        moe_layer_end_index = getattr(
+            config, "moe_layer_end_index", config.num_hidden_layers - 1
+        )
+        # MLP
+        if (not is_mtp) and (
+            moe_layer_start_index <= layer_id <= moe_layer_end_index
+            and (layer_id - moe_layer_start_index) % config.moe_layer_interval == 0
+        ):
+            self.mlp = Ernie4Moe(
+                config=config,
+                layer_id=layer_id,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
+        else:
+            if enable_moe_dense_fully_dp():
+                mlp_tp_rank, mlp_tp_size = 0, 1
+            else:
+                mlp_tp_rank, mlp_tp_size = None, None
+            self.mlp = Ernie4MLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+                tp_rank=mlp_tp_rank,
+                tp_size=mlp_tp_size,
+            )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+
+        return hidden_states, residual
+
+
+class Ernie4Model(nn.Module):
+    def __init__(
+        self,
+        config: Ernie4_5_MoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("embed_tokens", prefix),
+        )
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Ernie4DecoderLayer(
+                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
+            ),
+            prefix="model.layers",
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+
+        return hidden_states
+
+
+class Ernie4_5_ForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
+    }
+    stacked_params_mapping = [
+        # (param_name, weight_name, shard_id)
+        (".qkv_proj", ".q_proj", "q"),
+        (".qkv_proj", ".k_proj", "k"),
+        (".qkv_proj", ".v_proj", "v"),
+        (".gate_up_proj", ".gate_proj", 0),
+        (".gate_up_proj", ".up_proj", 1),
+    ]
+
+    def __init__(
+        self,
+        config: Ernie4_5_MoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config: Ernie4_5_MoeConfig = config
+        self.quant_config = quant_config
+        self.model = Ernie4Model(config, quant_config, add_prefix("model", prefix))
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix="lm_head",
+            )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    raise KeyError(f"Parameter '{name}' not found in model.")
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+
+class Ernie4_5_MoeForCausalLM(Ernie4_5_ForCausalLM):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.moe_num_experts,
+        )
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            if name.startswith("model.mtp_"):
+                continue
+            if "moe_statics.e_score_correction_bias" in name:
+                name = name.replace("moe_statics", "gate")
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if ("mlp.experts." in name) and name not in params_dict:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    if name in params_dict.keys():
+                        param = params_dict[name]
+                        weight_loader = param.weight_loader
+                        weight_loader(
+                            param,
+                            loaded_weight,
+                            name,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                        )
+                    else:
+                        raise KeyError(
+                            f"Parameter '{name}'(replaced) not found in model."
+                        )
+                    break
+                else:
+                    if name in params_dict.keys():
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+                    else:
+                        raise KeyError(f"Parameter '{name}' not found in model.")
+
+
+EntryClass = [Ernie4_5_MoeForCausalLM, Ernie4_5_ForCausalLM]
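
The `EntryClass` list above is how the new dense and MoE ERNIE-4.5 classes get picked up by the model registry (`sglang/srt/models/registry.py` in the file list). A minimal serving sketch follows; it assumes the offline `sglang.Engine` API behaves as in previous releases, and the checkpoint name `baidu/ERNIE-4.5-21B-A3B-PT` is used purely as an example of the `baidu/ERNIE-4.5-*-PT` family named in the module docstring.

```python
# Hedged usage sketch, not taken from this diff: load one of the newly supported
# ERNIE-4.5 checkpoints with the offline engine and run a single greedy generation.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="baidu/ERNIE-4.5-21B-A3B-PT")  # example checkpoint
    outputs = llm.generate(
        ["The capital of France is"],
        {"temperature": 0.0, "max_new_tokens": 32},
    )
    print(outputs[0]["text"])
    llm.shutdown()
```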
sglang/srt/models/ernie4_eagle.py (new file)
@@ -0,0 +1,203 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+""" Ernie4.5 MTP model compatible with baidu/ERNIE-4.5-*-PT weights. """
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers.models.ernie4_5_moe.configuration_ernie4_5_moe import (
+    Ernie4_5_MoeConfig,
+)
+
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.ernie4 import Ernie4_5_ForCausalLM, Ernie4DecoderLayer
+from sglang.srt.utils import add_prefix
+
+
+class Ernie4ModelMTP(nn.Module):
+    def __init__(
+        self,
+        config: Ernie4_5_MoeConfig,
+        layer_id: int,
+        prefix: str,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("embed_tokens", prefix),
+        )
+        self.mtp_emb_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.mtp_hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.mtp_linear_proj = nn.Linear(
+            config.hidden_size * 2, config.hidden_size, bias=config.use_bias
+        )
+        self.mtp_block = Ernie4DecoderLayer(
+            config=config,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("mtp_block", prefix),
+            is_mtp=True,
+        )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        # masking inputs at position 0, as not needed by MTP
+        hidden_states[positions == 0] = 0
+
+        hidden_states = self.mtp_linear_proj(
+            torch.cat(
+                (
+                    self.mtp_emb_norm(hidden_states),
+                    self.mtp_hidden_norm(forward_batch.spec_info.hidden_states),
+                ),
+                dim=-1,
+            )
+        )
+        residual = None
+        hidden_states, residual = self.mtp_block(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+            residual=residual,
+        )
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class Ernie4_5_MoeForCausalLMMTP(nn.Module):
+    def __init__(
+        self,
+        config: Ernie4_5_MoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        mtp_layer_id: int = 0,
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.mtp_layer_id = mtp_layer_id
+
+        self.model = Ernie4ModelMTP(
+            config=config,
+            layer_id=self.mtp_layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("model", prefix),
+        )
+
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix="lm_head",
+            )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        mtp_layer_found = False
+        mtp_weight_patterns = [
+            f"mtp_block.{self.mtp_layer_id}",
+            f"mtp_emb_norm.{self.mtp_layer_id}",
+            f"mtp_hidden_norm.{self.mtp_layer_id}",
+            f"mtp_linear_proj.{self.mtp_layer_id}",
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            # Only name matched patterns should be loaded
+            for layer_pattern in mtp_weight_patterns:
+                if layer_pattern in name:
+                    mtp_layer_found = True
+                    break
+            else:
+                continue
+            # But strip mtp_layer_id before loading, because each MTP layer is a MTP model.
+            name = name.replace(f".{self.mtp_layer_id}.", ".")
+            for (
+                param_name,
+                weight_name,
+                shard_id,
+            ) in Ernie4_5_ForCausalLM.stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    raise KeyError(f"Parameter '{name}' not found in MTP model.")
+        if not mtp_layer_found:
+            raise KeyError(
+                f"MTP layers 'mtp_*.{self.mtp_layer_id}.*' not found in weights."
+            )
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        self.model.embed_tokens.weight = embed
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            del self.lm_head.weight
+            self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+
+EntryClass = [Ernie4_5_MoeForCausalLMMTP]
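
The heart of `Ernie4ModelMTP.forward` is the multi-token-prediction combine step: the RMS-normalized embedding of the draft token and the RMS-normalized hidden state handed over by the target model (`forward_batch.spec_info.hidden_states`) are concatenated and projected back to `hidden_size` before the single `Ernie4DecoderLayer` runs. The toy sketch below mirrors only that tensor plumbing with invented shapes; it does not use the real `ForwardBatch` machinery and assumes PyTorch >= 2.4 for `nn.RMSNorm`.

```python
# Toy illustration of the MTP combine-and-project step (shapes invented).
import torch
from torch import nn

hidden_size, num_tokens = 64, 4
emb_norm = nn.RMSNorm(hidden_size)     # stands in for mtp_emb_norm
hidden_norm = nn.RMSNorm(hidden_size)  # stands in for mtp_hidden_norm
linear_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)  # mtp_linear_proj

token_embeds = torch.randn(num_tokens, hidden_size)  # embeddings of draft tokens
prev_hidden = torch.randn(num_tokens, hidden_size)   # hidden states from the target model

combined = torch.cat((emb_norm(token_embeds), hidden_norm(prev_hidden)), dim=-1)
mtp_input = linear_proj(combined)  # back to (num_tokens, hidden_size)
print(mtp_input.shape)             # torch.Size([4, 64])
```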
sglang/srt/models/gemma2.py
@@ -432,40 +432,6 @@ class Gemma2ForCausalLM(nn.Module):
 
         return result
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "qkv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_attention_heads,
-            )
-        elif module_name in ["o_proj"]:
-            return (
-                self.config.head_dim * self.config.num_attention_heads,
-                self.config.hidden_size,
-            )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
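
The `get_hidden_dim` / `get_module_name` hooks removed here from `Gemma2ForCausalLM` were LoRA helpers; the same pair disappears from `granite.py`, `llama.py`, `qwen3.py`, and `torch_native_llama.py` in this release, alongside the reworked `sglang/srt/lora/` modules, which suggests the LoRA code now derives these shapes on its own. The sketch below reconstructs the deleted per-model mapping generically from a Llama-style config; it is an illustration of the removed logic, not the replacement implementation in `sglang/srt/lora/utils.py`.

```python
# Generic stand-in for the deleted Gemma2ForCausalLM.get_hidden_dim(): compute the
# (input_dim, output_dim) of each stacked projection from common HF config fields.
from typing import Tuple


def hidden_dims(config, module_name: str) -> Tuple[int, int]:
    head_dim = getattr(
        config, "head_dim", config.hidden_size // config.num_attention_heads
    )
    if module_name in ("q_proj", "qkv_proj"):
        return config.hidden_size, head_dim * config.num_attention_heads
    if module_name == "kv_proj":
        return config.hidden_size, head_dim * config.num_key_value_heads
    if module_name == "o_proj":
        return head_dim * config.num_attention_heads, config.hidden_size
    if module_name == "gate_up_proj":
        return config.hidden_size, config.intermediate_size
    if module_name == "down_proj":
        return config.intermediate_size, config.hidden_size
    raise NotImplementedError(module_name)
```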