sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +7 -0
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +16 -1
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mooncake/conn.py +16 -0
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +13 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +132 -79
- sglang/srt/function_call/ebnf_composer.py +10 -3
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/qwen3_coder_detector.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +14 -3
- sglang/srt/layers/moe/ep_moe/layer.py +323 -242
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +90 -24
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +27 -10
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/lora/lora_registry.py +93 -29
- sglang/srt/managers/cache_controller.py +9 -7
- sglang/srt/managers/data_parallel_controller.py +4 -0
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +14 -8
- sglang/srt/managers/scheduler.py +64 -1
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/tokenizer_manager.py +80 -15
- sglang/srt/managers/tp_worker.py +8 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -2
- sglang/srt/model_executor/model_runner.py +83 -27
- sglang/srt/models/deepseek_v2.py +75 -84
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/qwen2_moe.py +2 -2
- sglang/srt/models/qwen3_moe.py +17 -71
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +65 -6
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/srt/utils.py +96 -1
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +118 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4_moe.py (new file)
@@ -0,0 +1,1035 @@
+# Copyright 2025-2026 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Inference-only GLM-4.5 model compatible with HuggingFace weights"""
+
+import logging
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    parallel_state,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.amx_utils import PackWeightMethod
+from sglang.srt.layers.communicator import (
+    LayerCommunicator,
+    LayerScatterModes,
+    enable_moe_dense_fully_dp,
+)
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    get_local_attention_dp_size,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import (
+    DeepEPMoE,
+    get_moe_impl_class,
+    use_flashinfer_trtllm_moe,
+)
+from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_kernel import (
+    is_fp8_fnuz,
+    per_tensor_quant_mla_fp8,
+    per_token_group_quant_mla_deep_gemm_masked_fp8,
+)
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.deepseek_v2 import (
+    DeepseekV2DecoderLayer,
+    DeepseekV2ForCausalLM,
+    DeepseekV2Model,
+    DeepseekV2MoE,
+)
+from sglang.srt.two_batch_overlap import (
+    MaybeTboDeepEPDispatcher,
+    model_forward_maybe_tbo,
+)
+from sglang.srt.utils import (
+    BumpAllocator,
+    DeepEPMode,
+    LazyValue,
+    add_prefix,
+    bind_or_assign,
+    cpu_has_amx_support,
+    get_bool_env_var,
+    get_device_sm,
+    get_int_env_var,
+    is_cpu,
+    is_cuda,
+    is_flashinfer_available,
+    is_hip,
+    is_non_idle_and_non_empty,
+    log_info_on_rank0,
+    use_intel_amx_backend,
+)
+
+_is_hip = is_hip()
+_is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+_device_sm = get_device_sm()
+
+if _is_cuda:
+    from sgl_kernel import dsv3_router_gemm
+elif _is_cpu and _is_cpu_amx_available:
+    pass
+
+logger = logging.getLogger(__name__)
+
+
+class Glm4MoeMLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
+        prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        self.tp_size = tp_size
+
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+            prefix=add_prefix("down_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
+        if (self.tp_size == 1) and x.shape[0] == 0:
+            return x
+
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
+        return x
+
+
+class Glm4MoeAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        partial_rotary_factor: float = 0.5,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        head_dim: Optional[int] = None,
+        rms_norm_eps: float = 1e-05,
+        attention_bias: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        use_qk_norm: bool = False,
+        prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= attn_tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % attn_tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
+        self.head_dim = head_dim or hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.use_qk_norm = use_qk_norm
+        self.max_position_embeddings = max_position_embeddings
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=attention_bias,
+            quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            reduce_results=False,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            partial_rotary_factor=partial_rotary_factor,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
+        )
+
+        if self.use_qk_norm:
+            self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+            self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.alt_stream = alt_stream
+
+    def _apply_qk_norm(
+        self, q: torch.Tensor, k: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # overlap qk norm
+        if self.alt_stream is not None and get_is_capture_mode():
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            with torch.cuda.stream(self.alt_stream):
+                k_by_head = k.reshape(-1, self.head_dim)
+                k_by_head = self.k_norm(k_by_head)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            k_by_head = k.reshape(-1, self.head_dim)
+            k_by_head = self.k_norm(k_by_head)
+        q = q_by_head.view(q.shape)
+        k = k_by_head.view(k.shape)
+        return q, k
+
+    def op_prepare(self, state):
+        state.attn_intermediate_state = self.forward_prepare(
+            positions=state.positions,
+            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+            forward_batch=state.forward_batch,
+        )
+
+    def op_core(self, state):
+        state.hidden_states_after_attn = self.forward_core(
+            state.pop("attn_intermediate_state")
+        )
+
+    def forward_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states, forward_batch, None
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.use_qk_norm:
+            q, k = self._apply_qk_norm(q, k)
+        q, k = self.rotary_emb(positions, q, k)
+        inner_state = q, k, v, forward_batch
+        return None, forward_batch, inner_state
+
+    def forward_core(self, intermediate_state):
+        hidden_states, forward_batch, inner_state = intermediate_state
+        if inner_state is None:
+            return hidden_states
+        attn_output = self.attn(*inner_state)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        s = self.forward_prepare(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        return self.forward_core(s)
+
+
+class Glm4MoeGate(nn.Module):
+    def __init__(
+        self,
+        config,
+        prefix: str = "",
+        is_nextn: bool = False,
+    ):
+        super().__init__()
+        self.is_nextn = is_nextn
+        self.weight = nn.Parameter(
+            torch.empty((config.n_routed_experts, config.hidden_size))
+        )
+        self.e_score_correction_bias = nn.Parameter(
+            torch.empty((config.n_routed_experts))
+        )
+        if _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])
+
+    def forward(self, hidden_states):
+        if use_intel_amx_backend(self):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                hidden_states,
+                self.weight,
+                None,  # bias
+                True,  # is_vnni
+            )
+
+        # NOTE: For some unknown reason, router_gemm seems degrade accept length.
+        if (
+            _is_cuda
+            and not self.is_nextn
+            and hidden_states.shape[0] < 4
+            and hidden_states.shape[1] == 7168
+            and self.weight.shape[0] == 256
+            and _device_sm >= 90
+        ):
+            logits = dsv3_router_gemm(hidden_states, self.weight).to(
+                hidden_states.dtype
+            )
+        else:
+            logits = F.linear(hidden_states, self.weight, None)
+
+        return logits
+
+
+class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
+        is_nextn: bool = False,
+    ):
+        nn.Module.__init__(self)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.routed_scaling_factor = config.routed_scaling_factor
+        self.n_shared_experts = config.n_shared_experts
+        self.num_fused_shared_experts = (
+            0
+            if global_server_args_dict["disable_shared_experts_fusion"]
+            else config.n_shared_experts
+        )
+        self.config = config
+        self.layer_id = layer_id
+        self.alt_stream = alt_stream
+
+        if self.tp_size > config.n_routed_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.n_routed_experts}."
+            )
+
+        if config.hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {config.hidden_act}. "
+                "Only silu is supported for now."
+            )
+
+        self.gate = Glm4MoeGate(
+            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
+        )
+
+        self.topk = (
+            TopK(
+                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+                renormalize=config.norm_topk_prob,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                topk_group=config.topk_group,
+                correction_bias=self.gate.e_score_correction_bias,
+                routed_scaling_factor=self.routed_scaling_factor,
+            )
+            if not use_flashinfer_trtllm_moe
+            else None
+        )
+
+        self.experts = get_moe_impl_class()(
+            num_experts=config.n_routed_experts
+            + self.num_fused_shared_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            layer_id=self.layer_id,
+            quant_config=quant_config,
+            routed_scaling_factor=self.routed_scaling_factor,
+            prefix=add_prefix("experts", prefix),
+            **(
+                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                if global_server_args_dict["enable_deepep_moe"]
+                else {}
+            ),
+            # Additional args for FusedMoE
+            **(
+                dict(
+                    enable_flashinfer_cutlass_moe=True,
+                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+                )
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
+                else {}
+            ),
+            **(
+                dict(
+                    renormalize=config.norm_topk_prob,
+                    use_grouped_topk=True,
+                    num_expert_group=config.n_group,
+                    num_fused_shared_experts=self.num_fused_shared_experts,
+                    topk_group=config.topk_group,
+                    correction_bias=self.gate.e_score_correction_bias,
+                )
+                if use_flashinfer_trtllm_moe
+                else {}
+            ),
+        )
+
+        self.shared_experts_is_int8 = False
+        self.shared_experts_is_fp8 = False
+        # self.shared_experts_weight_block_size = None
+        if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
+            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
+            self.shared_experts = Glm4MoeMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if global_server_args_dict["enable_deepep_moe"]
+                    else {}
+                ),
+            )
+            is_packed_weight = hasattr(
+                self.shared_experts.gate_up_proj.quant_method, "quant_config"
+            )
+            self.shared_experts_is_int8 = (
+                not is_packed_weight
+                and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
+            )
+            self.shared_experts_is_fp8 = (
+                not is_packed_weight
+                and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
+            )
+
+        self.top_k = config.num_experts_per_tok
+
+        if global_server_args_dict["enable_deepep_moe"]:
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_tensor_model_parallel_world_size()
+            self.num_experts = (
+                config.n_routed_experts
+                + global_server_args_dict["ep_num_redundant_experts"]
+            )
+            self.renormalize = config.norm_topk_prob
+            self.topk_group = config.topk_group
+            self.num_expert_group = config.n_group
+            self.correction_bias = (
+                self.gate.e_score_correction_bias.data
+                if self.gate.e_score_correction_bias is not None
+                else None
+            )
+
+            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
+                group=parallel_state.get_tp_group().device_group,
+                router_topk=self.top_k,
+                permute_fusion=True,
+                num_experts=self.num_experts,
+                num_local_experts=config.n_routed_experts // self.tp_size,
+                hidden_size=config.hidden_size,
+                params_dtype=config.torch_dtype,
+                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+                async_finish=True,
+                return_recv_hook=True,
+            )
+
+        self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
+
+
+class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        is_nextn: bool = False,
+        prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
+        nn.Module.__init__(self)
+        self.hidden_size = config.hidden_size
+        self.config = config
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        rms_norm_eps = config.rms_norm_eps
+        attention_bias = config.attention_bias
+        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.layer_id = layer_id
+        self.self_attn = Glm4MoeAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            partial_rotary_factor=partial_rotary_factor,
+            max_position_embeddings=max_position_embeddings,
+            head_dim=head_dim,
+            rms_norm_eps=rms_norm_eps,
+            attention_bias=attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+            use_qk_norm=config.use_qk_norm,
+        )
+
+        self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
+        is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)
+
+        num_layers = 1 if is_nextn else config.num_hidden_layers
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=num_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
+        )
+
+        if self.is_layer_sparse:
+            self.mlp = Glm4MoeSparseMoeBlock(
+                config=config,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+                layer_id=self.layer_id,
+            )
+        else:
+            if enable_moe_dense_fully_dp():
+                mlp_tp_rank, mlp_tp_size = 0, 1
+            else:
+                mlp_tp_rank, mlp_tp_size = None, None
+            self.mlp = Glm4MoeMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+                tp_rank=mlp_tp_rank,
+                tp_size=mlp_tp_size,
+            )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
+    ) -> torch.Tensor:
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch)
+
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
+
+        return hidden_states, residual
+
+
+class Glm4MoeModel(DeepseekV2Model):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+        self.padding_id = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.first_k_dense_replace = config.first_k_dense_replace
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
+        )
+        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
+        self.layers = nn.ModuleList(
+            [
+                Glm4MoeDecoderLayer(
+                    config,
+                    layer_id,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{layer_id}", prefix),
+                    alt_stream=self.alt_stream,
+                )
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.dp_size = get_local_attention_dp_size()
+
+
+class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+        config.moe_layer_freq = 1
+        self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.quant_config = quant_config
+        self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
+        self.model = Glm4MoeModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.dp_size = get_local_attention_dp_size()
+
+        self._routed_experts_weights_of_layer = LazyValue(
+            lambda: {
+                layer_id: layer.mlp.get_moe_weights()
+                for layer_id, layer in enumerate(self.model.layers)
+                if isinstance(layer.mlp, DeepseekV2MoE)
+            }
+        )
+
+    def determine_num_fused_shared_experts(
+        self, architecture: str = "DeepseekV3ForCausalLM"
+    ):
+        self.num_fused_shared_experts = 0
+        if global_server_args_dict["disable_shared_experts_fusion"]:
+            return
+
+        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+        disable_reason = None
+        if (
+            not _is_cuda
+            or torch.cuda.get_device_capability("cuda") < (8, 0)
+            or self.config.architectures[0] != architecture
+            or self.config.n_routed_experts != 128
+            or self.config.n_shared_experts != 1
+        ):
+            disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
+        elif (
+            global_server_args_dict["enable_deepep_moe"]
+            or global_server_args_dict["enable_ep_moe"]
+        ):
+            disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+
+        if disable_reason is not None:
+            global_server_args_dict["disable_shared_experts_fusion"] = True
+            log_info_on_rank0(
+                logger,
+                f"{disable_reason} Shared experts fusion optimization is disabled.",
+            )
+            return
+
+        self.num_fused_shared_experts = self.config.n_shared_experts
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+
+        if is_nextn:
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
+                # compatible with old design
+                nextn_layer_id = (
+                    0
+                    if self.config.num_hidden_layers == 1
+                    else self.config.num_hidden_layers
+                )
+            else:
+                raise ValueError("num_nextn_predict_layers is not in the config")
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        if self.num_fused_shared_experts > 0:
+            assert self.num_fused_shared_experts == 1
+            weights_list = list(weights)
+            weights_dict = dict(weights_list)
+            if self.quant_config is not None:
+                if self.quant_config.get_name() == "w8a8_int8":
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                    ]
+                elif (
+                    self.quant_config.get_name() == "fp8"
+                    or self.quant_config.get_name() == "blockwise_int8"
+                    or self.quant_config.get_name() == "compressed_tensors"
+                ):
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                    ]
+                elif self.quant_config.get_name() == "awq":
+                    suffix_list = [
+                        "down_proj.qweight",
+                        "down_proj.qzeros",
+                        "down_proj.scales",
+                        "gate_proj.qweight",
+                        "gate_proj.qzeros",
+                        "gate_proj.scales",
+                        "up_proj.qweight",
+                        "up_proj.qzeros",
+                        "up_proj.scales",
+                    ]
+                elif self.quant_config.get_name() == "modelopt_fp4":
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "down_proj.weight_scale_2",
+                        "down_proj.input_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "gate_proj.weight_scale_2",
+                        "gate_proj.input_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                        "up_proj.weight_scale_2",
+                        "up_proj.input_scale",
+                    ]
+                else:
+                    raise ValueError(
+                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
+                    )
+            else:
+                suffix_list = [
+                    "down_proj.weight",
+                    "gate_proj.weight",
+                    "up_proj.weight",
+                ]
+            names_to_remove = []
+
+            moe_layers = (
+                range(
+                    self.config.first_k_dense_replace,
+                    self.config.num_hidden_layers,
+                    self.config.moe_layer_freq,
+                )
+                if not is_nextn
+                else [nextn_layer_id]
+            )
+
+            for moe_layer in moe_layers:
+                for suffix in suffix_list:
+                    shared_expert_weight_name = (
+                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
+                    )
+                    # online fp8 quantization does not load weight_scale
+                    if shared_expert_weight_name not in weights_dict:
+                        continue
+                    weights_list.append(
+                        (
+                            f"model.layers.{moe_layer}."
+                            f"mlp.experts."
+                            f"{self.config.n_routed_experts + 0}"
+                            f".{suffix}",
+                            weights_dict[shared_expert_weight_name],
+                        )
+                    )
+                    names_to_remove += [shared_expert_weight_name]
+            weights = [w for w in weights_list if w[0] not in names_to_remove]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
+        )
+
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+            self.config.q_lora_rank is not None
+        )
+        cached_a_proj = {} if fuse_qkv_a_proj else None
+
+        if is_nextn:
+            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+            nextn_spec_weight_names = [
+                "shared_head.norm",
+                "eh_proj",
+                "enorm",
+                "hnorm",
+            ]
+
+        params_dict = dict(self.named_parameters())
+        weight_names = []
+        for name, loaded_weight in weights:
+            weight_names.append(name)
+
+            if not is_nextn:
+                if hasattr(self.config, "num_nextn_predict_layers"):
+                    num_nextn_layers = self.config.num_nextn_predict_layers
+                    if num_nextn_layers > 0 and name.startswith("model.layers"):
+                        name_list = name.split(".")
+                        if (
+                            len(name_list) >= 3
+                            and int(name_list[2]) >= self.config.num_hidden_layers
+                        ):
+                            continue
+            else:
+                if not name.startswith(nextn_layer_prefix):
+                    continue
+
+                # Use shared head and embed weights from target model
+                if "shared_head.head" in name or "embed_tokens" in name:
+                    continue
+
+                is_decoder = True
+                # For nextn specific weights
+                for weight_name in nextn_spec_weight_names:
+                    if weight_name in name:
+                        name = name.replace(nextn_layer_prefix, "model")
+                        is_decoder = False
+                        break
+                # For decoder layer weights
+                if is_decoder:
+                    name = name.replace(nextn_layer_prefix, "model.decoder")
+
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if ("mlp.experts." in name) and name not in params_dict:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if fuse_qkv_a_proj and (
+                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                    ):
+                        cached_a_proj[name] = loaded_weight
+                        q_a_proj_name = (
+                            name
+                            if "q_a_proj" in name
+                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                        )
+                        kv_a_proj_name = (
+                            name
+                            if "kv_a_proj_with_mqa" in name
+                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                        )
+
+                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                        if (
+                            q_a_proj_name in cached_a_proj
+                            and kv_a_proj_name in cached_a_proj
+                        ):
+                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                            fused_weight = torch.cat(
+                                [q_a_proj_weight, kv_a_proj_weight], dim=0
+                            )
+                            param_name = (
+                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                                if "q_a_proj" in name
+                                else name.replace(
+                                    "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+                                )
+                            )
+                            param = params_dict[param_name]
+
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            weight_loader(param, fused_weight)
+                            cached_a_proj.pop(q_a_proj_name)
+                            cached_a_proj.pop(kv_a_proj_name)
+                    else:
+                        if (
+                            "k_scale" in name or "v_scale" in name
+                        ) and name not in params_dict:
+                            # modelopt attn kv scale is named differently
+                            if any(scale in name for scale in ["k_scale", "v_scale"]):
+                                name = name.replace("_proj", "attn_mqa")
+                            else:
+                                logger.warning(
+                                    f"Unknown scale found in checkpoint: {name}"
+                                )
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+
+
+EntryClass = [Glm4MoeForCausalLM]
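For context, the `EntryClass = [Glm4MoeForCausalLM]` line at the end of the added file is how SGLang's model registry discovers model implementations under sglang/srt/models. A minimal offline sketch of exercising such a checkpoint through SGLang's documented Engine API follows; the model path and sampling parameters below are illustrative assumptions and are not taken from this diff.

import sglang as sgl

if __name__ == "__main__":
    # Hypothetical checkpoint path; any model whose architecture maps to an
    # EntryClass in sglang/srt/models is loaded the same way.
    llm = sgl.Engine(model_path="zai-org/GLM-4.5")
    prompts = ["Explain mixture-of-experts routing in one sentence."]
    sampling_params = {"temperature": 0.7, "max_new_tokens": 64}
    # generate() returns one dict per prompt; the decoded completion is in "text".
    outputs = llm.generate(prompts, sampling_params)
    for prompt, out in zip(prompts, outputs):
        print(prompt, "->", out["text"])
    llm.shutdown()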