sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/_custom_ops.py +29 -1
  3. sglang/srt/configs/internvl.py +3 -0
  4. sglang/srt/configs/model_config.py +5 -1
  5. sglang/srt/constrained/base_grammar_backend.py +10 -2
  6. sglang/srt/constrained/xgrammar_backend.py +7 -5
  7. sglang/srt/conversation.py +17 -2
  8. sglang/srt/debug_utils/__init__.py +0 -0
  9. sglang/srt/debug_utils/dump_comparator.py +131 -0
  10. sglang/srt/debug_utils/dumper.py +108 -0
  11. sglang/srt/debug_utils/text_comparator.py +172 -0
  12. sglang/srt/disaggregation/common/conn.py +34 -6
  13. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  14. sglang/srt/disaggregation/mini_lb.py +3 -2
  15. sglang/srt/disaggregation/mooncake/conn.py +65 -20
  16. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  17. sglang/srt/disaggregation/nixl/conn.py +17 -13
  18. sglang/srt/disaggregation/prefill.py +13 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  21. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  23. sglang/srt/distributed/parallel_state.py +70 -15
  24. sglang/srt/entrypoints/engine.py +5 -9
  25. sglang/srt/entrypoints/http_server.py +20 -32
  26. sglang/srt/entrypoints/openai/protocol.py +3 -3
  27. sglang/srt/entrypoints/openai/serving_chat.py +148 -72
  28. sglang/srt/function_call/base_format_detector.py +74 -12
  29. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  30. sglang/srt/function_call/ebnf_composer.py +105 -66
  31. sglang/srt/function_call/function_call_parser.py +6 -4
  32. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  33. sglang/srt/function_call/kimik2_detector.py +41 -16
  34. sglang/srt/function_call/llama32_detector.py +6 -3
  35. sglang/srt/function_call/mistral_detector.py +11 -3
  36. sglang/srt/function_call/pythonic_detector.py +16 -14
  37. sglang/srt/function_call/qwen25_detector.py +12 -3
  38. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
  39. sglang/srt/layers/activation.py +11 -3
  40. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  41. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  42. sglang/srt/layers/attention/vision.py +56 -8
  43. sglang/srt/layers/communicator.py +12 -12
  44. sglang/srt/layers/dp_attention.py +72 -24
  45. sglang/srt/layers/layernorm.py +26 -1
  46. sglang/srt/layers/logits_processor.py +46 -25
  47. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  48. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  51. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  52. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  53. sglang/srt/layers/moe/topk.py +88 -34
  54. sglang/srt/layers/multimodal.py +11 -8
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  56. sglang/srt/layers/quantization/fp8.py +25 -247
  57. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  58. sglang/srt/layers/quantization/modelopt_quant.py +33 -14
  59. sglang/srt/layers/quantization/unquant.py +24 -76
  60. sglang/srt/layers/quantization/utils.py +0 -9
  61. sglang/srt/layers/quantization/w4afp8.py +68 -17
  62. sglang/srt/layers/radix_attention.py +5 -3
  63. sglang/srt/lora/lora_manager.py +133 -169
  64. sglang/srt/lora/lora_registry.py +188 -0
  65. sglang/srt/lora/mem_pool.py +2 -2
  66. sglang/srt/managers/cache_controller.py +62 -13
  67. sglang/srt/managers/io_struct.py +19 -1
  68. sglang/srt/managers/mm_utils.py +154 -35
  69. sglang/srt/managers/multimodal_processor.py +3 -14
  70. sglang/srt/managers/schedule_batch.py +27 -11
  71. sglang/srt/managers/scheduler.py +48 -26
  72. sglang/srt/managers/tokenizer_manager.py +62 -28
  73. sglang/srt/managers/tp_worker.py +5 -4
  74. sglang/srt/mem_cache/allocator.py +67 -7
  75. sglang/srt/mem_cache/hicache_storage.py +17 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +35 -18
  77. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  78. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  79. sglang/srt/model_executor/forward_batch_info.py +201 -29
  80. sglang/srt/model_executor/model_runner.py +109 -37
  81. sglang/srt/models/deepseek_v2.py +63 -30
  82. sglang/srt/models/glm4_moe.py +1035 -0
  83. sglang/srt/models/glm4_moe_nextn.py +167 -0
  84. sglang/srt/models/interns1.py +328 -0
  85. sglang/srt/models/internvl.py +143 -47
  86. sglang/srt/models/llava.py +9 -5
  87. sglang/srt/models/minicpmo.py +4 -1
  88. sglang/srt/models/mllama4.py +10 -3
  89. sglang/srt/models/qwen2_moe.py +2 -6
  90. sglang/srt/models/qwen3_moe.py +6 -8
  91. sglang/srt/multimodal/processors/base_processor.py +20 -6
  92. sglang/srt/multimodal/processors/clip.py +2 -2
  93. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  94. sglang/srt/multimodal/processors/gemma3.py +2 -2
  95. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  96. sglang/srt/multimodal/processors/internvl.py +21 -8
  97. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  98. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  99. sglang/srt/multimodal/processors/llava.py +4 -4
  100. sglang/srt/multimodal/processors/minicpm.py +2 -3
  101. sglang/srt/multimodal/processors/mlama.py +2 -2
  102. sglang/srt/multimodal/processors/mllama4.py +18 -111
  103. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  104. sglang/srt/multimodal/processors/pixtral.py +2 -2
  105. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  106. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  107. sglang/srt/multimodal/processors/vila.py +3 -1
  108. sglang/srt/reasoning_parser.py +48 -5
  109. sglang/srt/sampling/sampling_batch_info.py +6 -5
  110. sglang/srt/server_args.py +132 -60
  111. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  112. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  113. sglang/srt/speculative/eagle_utils.py +51 -23
  114. sglang/srt/speculative/eagle_worker.py +59 -44
  115. sglang/srt/two_batch_overlap.py +9 -5
  116. sglang/srt/utils.py +113 -69
  117. sglang/srt/weight_sync/utils.py +119 -0
  118. sglang/test/runners.py +4 -0
  119. sglang/test/test_activation.py +50 -1
  120. sglang/test/test_utils.py +65 -5
  121. sglang/utils.py +19 -0
  122. sglang/version.py +1 -1
  123. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
  124. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
  125. sglang/srt/debug_utils.py +0 -74
  126. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  127. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  128. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
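The registry comparison above can also be reproduced locally. A minimal sketch, assuming network access to PyPI, a `pip` executable on PATH, and illustrative directory names; it downloads both wheels and prints a unified diff per changed member, which is how a newly added file such as sglang/srt/models/glm4_moe.py appears as one large `+` hunk like the one below.

import difflib
import pathlib
import subprocess
import zipfile

# Download each wheel from PyPI without dependencies (directory names are illustrative).
for version in ("0.4.9.post3", "0.4.9.post5"):
    subprocess.run(
        ["pip", "download", f"sglang=={version}", "--no-deps", "--only-binary=:all:", "-d", f"wheels/{version}"],
        check=True,
    )

def read_wheel(directory: str) -> dict:
    """Map archive member names to text for the single wheel in `directory` (binary members decoded leniently)."""
    wheel = next(pathlib.Path(directory).glob("*.whl"))
    with zipfile.ZipFile(wheel) as zf:
        return {name: zf.read(name).decode("utf-8", "replace") for name in zf.namelist()}

old = read_wheel("wheels/0.4.9.post3")
new = read_wheel("wheels/0.4.9.post5")

# Emit a unified diff for every member that changed, was added, or was removed.
for name in sorted(set(old) | set(new)):
    if old.get(name) != new.get(name):
        diff = difflib.unified_diff(
            old.get(name, "").splitlines(),
            new.get(name, "").splitlines(),
            fromfile=f"post3/{name}",
            tofile=f"post5/{name}",
            lineterm="",
        )
        print("\n".join(diff))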
sglang/srt/models/glm4_moe.py (new file)
@@ -0,0 +1,1035 @@
+ # Copyright 2025-2026 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Inference-only GLM-4.5 model compatible with HuggingFace weights"""
+
+ import logging
+ from typing import Any, Dict, Iterable, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from transformers import PretrainedConfig
+
+ from sglang.srt.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+     parallel_state,
+     tensor_model_parallel_all_reduce,
+ )
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.amx_utils import PackWeightMethod
+ from sglang.srt.layers.communicator import (
+     LayerCommunicator,
+     LayerScatterModes,
+     enable_moe_dense_fully_dp,
+ )
+ from sglang.srt.layers.dp_attention import (
+     get_attention_tp_rank,
+     get_attention_tp_size,
+     get_local_attention_dp_size,
+ )
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     ColumnParallelLinear,
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.ep_moe.layer import (
+     DeepEPMoE,
+     get_moe_impl_class,
+     use_flashinfer_trtllm_moe,
+ )
+ from sglang.srt.layers.moe.topk import TopK
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.quantization.fp8_kernel import (
+     is_fp8_fnuz,
+     per_tensor_quant_mla_fp8,
+     per_token_group_quant_mla_deep_gemm_masked_fp8,
+ )
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.models.deepseek_v2 import (
+     DeepseekV2DecoderLayer,
+     DeepseekV2ForCausalLM,
+     DeepseekV2Model,
+     DeepseekV2MoE,
+ )
+ from sglang.srt.two_batch_overlap import (
+     MaybeTboDeepEPDispatcher,
+     model_forward_maybe_tbo,
+ )
+ from sglang.srt.utils import (
+     BumpAllocator,
+     DeepEPMode,
+     LazyValue,
+     add_prefix,
+     bind_or_assign,
+     cpu_has_amx_support,
+     get_bool_env_var,
+     get_device_sm,
+     get_int_env_var,
+     is_cpu,
+     is_cuda,
+     is_flashinfer_available,
+     is_hip,
+     is_non_idle_and_non_empty,
+     log_info_on_rank0,
+     use_intel_amx_backend,
+ )
+
+ _is_hip = is_hip()
+ _is_cuda = is_cuda()
+ _is_fp8_fnuz = is_fp8_fnuz()
+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+ _is_cpu_amx_available = cpu_has_amx_support()
+ _is_cpu = is_cpu()
+ _device_sm = get_device_sm()
+
+ if _is_cuda:
+     from sgl_kernel import dsv3_router_gemm
+ elif _is_cpu and _is_cpu_amx_available:
+     pass
+
+ logger = logging.getLogger(__name__)
+
+
+ class Glm4MoeMLP(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+         quant_config: Optional[QuantizationConfig] = None,
+         reduce_results: bool = True,
+         prefix: str = "",
+         tp_rank: Optional[int] = None,
+         tp_size: Optional[int] = None,
+     ) -> None:
+         super().__init__()
+         self.tp_size = tp_size
+
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size,
+             [intermediate_size] * 2,
+             bias=False,
+             quant_config=quant_config,
+             prefix=add_prefix("gate_up_proj", prefix),
+             tp_rank=tp_rank,
+             tp_size=tp_size,
+         )
+         self.down_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             reduce_results=reduce_results,
+             prefix=add_prefix("down_proj", prefix),
+             tp_rank=tp_rank,
+             tp_size=tp_size,
+         )
+         if hidden_act != "silu":
+             raise ValueError(
+                 f"Unsupported activation: {hidden_act}. "
+                 "Only silu is supported for now."
+             )
+         self.act_fn = SiluAndMul()
+
+     def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
+         if (self.tp_size == 1) and x.shape[0] == 0:
+             return x
+
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
+         return x
+
+
+ class Glm4MoeAttention(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         layer_id: int = 0,
+         rope_theta: float = 10000,
+         partial_rotary_factor: float = 0.5,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         max_position_embeddings: int = 8192,
+         head_dim: Optional[int] = None,
+         rms_norm_eps: float = 1e-05,
+         attention_bias: bool = True,
+         quant_config: Optional[QuantizationConfig] = None,
+         use_qk_norm: bool = False,
+         prefix: str = "",
+         alt_stream: Optional[torch.cuda.Stream] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+
+         attn_tp_rank = get_attention_tp_rank()
+         attn_tp_size = get_attention_tp_size()
+
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % attn_tp_size == 0
+         self.num_heads = self.total_num_heads // attn_tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= attn_tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % attn_tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert attn_tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
+         self.head_dim = head_dim or hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.use_qk_norm = use_qk_norm
+         self.max_position_embeddings = max_position_embeddings
+         self.tp_rank = get_tensor_model_parallel_rank()
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=attention_bias,
+             quant_config=quant_config,
+             tp_rank=attn_tp_rank,
+             tp_size=attn_tp_size,
+             prefix=add_prefix("qkv_proj", prefix),
+         )
+
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             tp_rank=attn_tp_rank,
+             tp_size=attn_tp_size,
+             reduce_results=False,
+             prefix=add_prefix("o_proj", prefix),
+         )
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position_embeddings,
+             partial_rotary_factor=partial_rotary_factor,
+             base=rope_theta,
+             rope_scaling=rope_scaling,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             prefix=add_prefix("attn", prefix),
+         )
+
+         if self.use_qk_norm:
+             self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+             self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+         self.alt_stream = alt_stream
+
+     def _apply_qk_norm(
+         self, q: torch.Tensor, k: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # overlap qk norm
+         if self.alt_stream is not None and get_is_capture_mode():
+             current_stream = torch.cuda.current_stream()
+             self.alt_stream.wait_stream(current_stream)
+             q_by_head = q.reshape(-1, self.head_dim)
+             q_by_head = self.q_norm(q_by_head)
+             with torch.cuda.stream(self.alt_stream):
+                 k_by_head = k.reshape(-1, self.head_dim)
+                 k_by_head = self.k_norm(k_by_head)
+             current_stream.wait_stream(self.alt_stream)
+         else:
+             q_by_head = q.reshape(-1, self.head_dim)
+             q_by_head = self.q_norm(q_by_head)
+             k_by_head = k.reshape(-1, self.head_dim)
+             k_by_head = self.k_norm(k_by_head)
+         q = q_by_head.view(q.shape)
+         k = k_by_head.view(k.shape)
+         return q, k
+
+     def op_prepare(self, state):
+         state.attn_intermediate_state = self.forward_prepare(
+             positions=state.positions,
+             hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+             forward_batch=state.forward_batch,
+         )
+
+     def op_core(self, state):
+         state.hidden_states_after_attn = self.forward_core(
+             state.pop("attn_intermediate_state")
+         )
+
+     def forward_prepare(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ):
+         if hidden_states.shape[0] == 0:
+             return hidden_states, forward_batch, None
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         if self.use_qk_norm:
+             q, k = self._apply_qk_norm(q, k)
+         q, k = self.rotary_emb(positions, q, k)
+         inner_state = q, k, v, forward_batch
+         return None, forward_batch, inner_state
+
+     def forward_core(self, intermediate_state):
+         hidden_states, forward_batch, inner_state = intermediate_state
+         if inner_state is None:
+             return hidden_states
+         attn_output = self.attn(*inner_state)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         s = self.forward_prepare(
+             positions=positions,
+             hidden_states=hidden_states,
+             forward_batch=forward_batch,
+         )
+         return self.forward_core(s)
+
+
+ class Glm4MoeGate(nn.Module):
+     def __init__(
+         self,
+         config,
+         prefix: str = "",
+         is_nextn: bool = False,
+     ):
+         super().__init__()
+         self.is_nextn = is_nextn
+         self.weight = nn.Parameter(
+             torch.empty((config.n_routed_experts, config.hidden_size))
+         )
+         self.e_score_correction_bias = nn.Parameter(
+             torch.empty((config.n_routed_experts))
+         )
+         if _is_cpu and _is_cpu_amx_available:
+             self.quant_method = PackWeightMethod(weight_names=["weight"])
+
+     def forward(self, hidden_states):
+         if use_intel_amx_backend(self):
+             return torch.ops.sgl_kernel.weight_packed_linear(
+                 hidden_states,
+                 self.weight,
+                 None,  # bias
+                 True,  # is_vnni
+             )
+
+         # NOTE: For some unknown reason, router_gemm seems degrade accept length.
+         if (
+             _is_cuda
+             and not self.is_nextn
+             and hidden_states.shape[0] < 4
+             and hidden_states.shape[1] == 7168
+             and self.weight.shape[0] == 256
+             and _device_sm >= 90
+         ):
+             logits = dsv3_router_gemm(hidden_states, self.weight).to(
+                 hidden_states.dtype
+             )
+         else:
+             logits = F.linear(hidden_states, self.weight, None)
+
+         return logits
+
+
+ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+         alt_stream: Optional[torch.cuda.Stream] = None,
+         is_nextn: bool = False,
+     ):
+         nn.Module.__init__(self)
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.routed_scaling_factor = config.routed_scaling_factor
+         self.n_shared_experts = config.n_shared_experts
+         self.num_fused_shared_experts = (
+             0
+             if global_server_args_dict["disable_shared_experts_fusion"]
+             else config.n_shared_experts
+         )
+         self.config = config
+         self.layer_id = layer_id
+         self.alt_stream = alt_stream
+
+         if self.tp_size > config.n_routed_experts:
+             raise ValueError(
+                 f"Tensor parallel size {self.tp_size} is greater than "
+                 f"the number of experts {config.n_routed_experts}."
+             )
+
+         if config.hidden_act != "silu":
+             raise ValueError(
+                 f"Unsupported activation: {config.hidden_act}. "
+                 "Only silu is supported for now."
+             )
+
+         self.gate = Glm4MoeGate(
+             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
+         )
+
+         self.topk = (
+             TopK(
+                 top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+                 renormalize=config.norm_topk_prob,
+                 use_grouped_topk=True,
+                 num_expert_group=config.n_group,
+                 num_fused_shared_experts=self.num_fused_shared_experts,
+                 topk_group=config.topk_group,
+                 correction_bias=self.gate.e_score_correction_bias,
+                 routed_scaling_factor=self.routed_scaling_factor,
+             )
+             if not use_flashinfer_trtllm_moe
+             else None
+         )
+
+         self.experts = get_moe_impl_class()(
+             num_experts=config.n_routed_experts
+             + self.num_fused_shared_experts
+             + global_server_args_dict["ep_num_redundant_experts"],
+             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+             hidden_size=config.hidden_size,
+             intermediate_size=config.moe_intermediate_size,
+             layer_id=self.layer_id,
+             quant_config=quant_config,
+             routed_scaling_factor=self.routed_scaling_factor,
+             prefix=add_prefix("experts", prefix),
+             **(
+                 dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                 if global_server_args_dict["enable_deepep_moe"]
+                 else {}
+             ),
+             # Additional args for FusedMoE
+             **(
+                 dict(
+                     enable_flashinfer_cutlass_moe=True,
+                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+                 )
+                 if global_server_args_dict["enable_flashinfer_cutlass_moe"]
+                 else {}
+             ),
+             **(
+                 dict(
+                     renormalize=config.norm_topk_prob,
+                     use_grouped_topk=True,
+                     num_expert_group=config.n_group,
+                     num_fused_shared_experts=self.num_fused_shared_experts,
+                     topk_group=config.topk_group,
+                     correction_bias=self.gate.e_score_correction_bias,
+                 )
+                 if use_flashinfer_trtllm_moe
+                 else {}
+             ),
+         )
+
+         self.shared_experts_is_int8 = False
+         self.shared_experts_is_fp8 = False
+         # self.shared_experts_weight_block_size = None
+         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
+             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
+             self.shared_experts = Glm4MoeMLP(
+                 hidden_size=config.hidden_size,
+                 intermediate_size=intermediate_size,
+                 hidden_act=config.hidden_act,
+                 quant_config=quant_config,
+                 reduce_results=False,
+                 prefix=add_prefix("shared_experts", prefix),
+                 **(
+                     dict(tp_rank=0, tp_size=1)
+                     if global_server_args_dict["enable_deepep_moe"]
+                     else {}
+                 ),
+             )
+             is_packed_weight = hasattr(
+                 self.shared_experts.gate_up_proj.quant_method, "quant_config"
+             )
+             self.shared_experts_is_int8 = (
+                 not is_packed_weight
+                 and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
+             )
+             self.shared_experts_is_fp8 = (
+                 not is_packed_weight
+                 and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
+             )
+
+         self.top_k = config.num_experts_per_tok
+
+         if global_server_args_dict["enable_deepep_moe"]:
+             # TODO: we will support tp < ep in the future
+             self.ep_size = get_tensor_model_parallel_world_size()
+             self.num_experts = (
+                 config.n_routed_experts
+                 + global_server_args_dict["ep_num_redundant_experts"]
+             )
+             self.renormalize = config.norm_topk_prob
+             self.topk_group = config.topk_group
+             self.num_expert_group = config.n_group
+             self.correction_bias = (
+                 self.gate.e_score_correction_bias.data
+                 if self.gate.e_score_correction_bias is not None
+                 else None
+             )
+
+             self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
+                 group=parallel_state.get_tp_group().device_group,
+                 router_topk=self.top_k,
+                 permute_fusion=True,
+                 num_experts=self.num_experts,
+                 num_local_experts=config.n_routed_experts // self.tp_size,
+                 hidden_size=config.hidden_size,
+                 params_dtype=config.torch_dtype,
+                 deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+                 async_finish=True,
+                 return_recv_hook=True,
+             )
+
+         self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
+
+
+ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         is_nextn: bool = False,
+         prefix: str = "",
+         alt_stream: Optional[torch.cuda.Stream] = None,
+     ) -> None:
+         nn.Module.__init__(self)
+         self.hidden_size = config.hidden_size
+         self.config = config
+         rope_theta = getattr(config, "rope_theta", 10000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+         partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
+         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+         head_dim = getattr(
+             config, "head_dim", config.hidden_size // config.num_attention_heads
+         )
+         rms_norm_eps = config.rms_norm_eps
+         attention_bias = config.attention_bias
+         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+         self.layer_id = layer_id
+         self.self_attn = Glm4MoeAttention(
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             layer_id=layer_id,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             partial_rotary_factor=partial_rotary_factor,
+             max_position_embeddings=max_position_embeddings,
+             head_dim=head_dim,
+             rms_norm_eps=rms_norm_eps,
+             attention_bias=attention_bias,
+             quant_config=quant_config,
+             prefix=add_prefix("self_attn", prefix),
+             use_qk_norm=config.use_qk_norm,
+         )
+
+         self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
+         is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)
+
+         num_layers = 1 if is_nextn else config.num_hidden_layers
+         self.layer_scatter_modes = LayerScatterModes.init_new(
+             layer_id=layer_id,
+             num_layers=num_layers,
+             is_layer_sparse=self.is_layer_sparse,
+             is_previous_layer_sparse=is_previous_layer_sparse,
+         )
+
+         if self.is_layer_sparse:
+             self.mlp = Glm4MoeSparseMoeBlock(
+                 config=config,
+                 quant_config=quant_config,
+                 prefix=add_prefix("mlp", prefix),
+                 layer_id=self.layer_id,
+             )
+         else:
+             if enable_moe_dense_fully_dp():
+                 mlp_tp_rank, mlp_tp_size = 0, 1
+             else:
+                 mlp_tp_rank, mlp_tp_size = None, None
+             self.mlp = Glm4MoeMLP(
+                 hidden_size=config.hidden_size,
+                 intermediate_size=config.intermediate_size,
+                 hidden_act=config.hidden_act,
+                 quant_config=quant_config,
+                 prefix=add_prefix("mlp", prefix),
+                 tp_rank=mlp_tp_rank,
+                 tp_size=mlp_tp_size,
+             )
+
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+         self.layer_communicator = LayerCommunicator(
+             layer_scatter_modes=self.layer_scatter_modes,
+             input_layernorm=self.input_layernorm,
+             post_attention_layernorm=self.post_attention_layernorm,
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+         residual: Optional[torch.Tensor],
+         zero_allocator: BumpAllocator,
+     ) -> torch.Tensor:
+         hidden_states, residual = self.layer_communicator.prepare_attn(
+             hidden_states, residual, forward_batch
+         )
+
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             forward_batch=forward_batch,
+         )
+
+         hidden_states, residual = self.layer_communicator.prepare_mlp(
+             hidden_states, residual, forward_batch
+         )
+
+         hidden_states = self.mlp(hidden_states, forward_batch)
+
+         hidden_states, residual = self.layer_communicator.postprocess_layer(
+             hidden_states, residual, forward_batch
+         )
+
+         return hidden_states, residual
+
+
+ class Glm4MoeModel(DeepseekV2Model):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         nn.Module.__init__(self)
+         self.padding_id = config.pad_token_id
+         self.vocab_size = config.vocab_size
+         self.first_k_dense_replace = config.first_k_dense_replace
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+             enable_tp=not global_server_args_dict["enable_dp_attention"],
+         )
+         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
+         self.layers = nn.ModuleList(
+             [
+                 Glm4MoeDecoderLayer(
+                     config,
+                     layer_id,
+                     quant_config=quant_config,
+                     prefix=add_prefix(f"layers.{layer_id}", prefix),
+                     alt_stream=self.alt_stream,
+                 )
+                 for layer_id in range(config.num_hidden_layers)
+             ]
+         )
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         self.dp_size = get_local_attention_dp_size()
+
+
+ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         nn.Module.__init__(self)
+         config.moe_layer_freq = 1
+         self.config = config
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.quant_config = quant_config
+         self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
+         self.model = Glm4MoeModel(
+             config, quant_config, prefix=add_prefix("model", prefix)
+         )
+         self.lm_head = ParallelLMHead(
+             config.vocab_size,
+             config.hidden_size,
+             quant_config=quant_config,
+             prefix=add_prefix("lm_head", prefix),
+             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+         )
+         self.logits_processor = LogitsProcessor(config)
+         self.dp_size = get_local_attention_dp_size()
+
+         self._routed_experts_weights_of_layer = LazyValue(
+             lambda: {
+                 layer_id: layer.mlp.get_moe_weights()
+                 for layer_id, layer in enumerate(self.model.layers)
+                 if isinstance(layer.mlp, DeepseekV2MoE)
+             }
+         )
+
+     def determine_num_fused_shared_experts(
+         self, architecture: str = "DeepseekV3ForCausalLM"
+     ):
+         self.num_fused_shared_experts = 0
+         if global_server_args_dict["disable_shared_experts_fusion"]:
+             return
+
+         # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+         disable_reason = None
+         if (
+             not _is_cuda
+             or torch.cuda.get_device_capability("cuda") < (8, 0)
+             or self.config.architectures[0] != architecture
+             or self.config.n_routed_experts != 128
+             or self.config.n_shared_experts != 1
+         ):
+             disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
+         elif (
+             global_server_args_dict["enable_deepep_moe"]
+             or global_server_args_dict["enable_ep_moe"]
+         ):
+             disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+
+         if disable_reason is not None:
+             global_server_args_dict["disable_shared_experts_fusion"] = True
+             log_info_on_rank0(
+                 logger,
+                 f"{disable_reason} Shared experts fusion optimization is disabled.",
+             )
+             return
+
+         self.num_fused_shared_experts = self.config.n_shared_experts
+
+     def get_input_embeddings(self) -> nn.Embedding:
+         return self.model.embed_tokens
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+
+         if is_nextn:
+             if hasattr(self.config, "num_nextn_predict_layers"):
+                 num_nextn_layers = self.config.num_nextn_predict_layers
+                 assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
+                 # compatible with old design
+                 nextn_layer_id = (
+                     0
+                     if self.config.num_hidden_layers == 1
+                     else self.config.num_hidden_layers
+                 )
+             else:
+                 raise ValueError("num_nextn_predict_layers is not in the config")
+
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         if self.num_fused_shared_experts > 0:
+             assert self.num_fused_shared_experts == 1
+             weights_list = list(weights)
+             weights_dict = dict(weights_list)
+             if self.quant_config is not None:
+                 if self.quant_config.get_name() == "w8a8_int8":
+                     suffix_list = [
+                         "down_proj.weight",
+                         "down_proj.weight_scale",
+                         "gate_proj.weight",
+                         "gate_proj.weight_scale",
+                         "up_proj.weight",
+                         "up_proj.weight_scale",
+                     ]
+                 elif (
+                     self.quant_config.get_name() == "fp8"
+                     or self.quant_config.get_name() == "blockwise_int8"
+                     or self.quant_config.get_name() == "compressed_tensors"
+                 ):
+                     suffix_list = [
+                         "down_proj.weight",
+                         "down_proj.weight_scale",
+                         "gate_proj.weight",
+                         "gate_proj.weight_scale",
+                         "up_proj.weight",
+                         "up_proj.weight_scale",
+                     ]
+                 elif self.quant_config.get_name() == "awq":
+                     suffix_list = [
+                         "down_proj.qweight",
+                         "down_proj.qzeros",
+                         "down_proj.scales",
+                         "gate_proj.qweight",
+                         "gate_proj.qzeros",
+                         "gate_proj.scales",
+                         "up_proj.qweight",
+                         "up_proj.qzeros",
+                         "up_proj.scales",
+                     ]
+                 elif self.quant_config.get_name() == "modelopt_fp4":
+                     suffix_list = [
+                         "down_proj.weight",
+                         "down_proj.weight_scale",
+                         "down_proj.weight_scale_2",
+                         "down_proj.input_scale",
+                         "gate_proj.weight",
+                         "gate_proj.weight_scale",
+                         "gate_proj.weight_scale_2",
+                         "gate_proj.input_scale",
+                         "up_proj.weight",
+                         "up_proj.weight_scale",
+                         "up_proj.weight_scale_2",
+                         "up_proj.input_scale",
+                     ]
+                 else:
+                     raise ValueError(
+                         f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
+                     )
+             else:
+                 suffix_list = [
+                     "down_proj.weight",
+                     "gate_proj.weight",
+                     "up_proj.weight",
+                 ]
+             names_to_remove = []
+
+             moe_layers = (
+                 range(
+                     self.config.first_k_dense_replace,
+                     self.config.num_hidden_layers,
+                     self.config.moe_layer_freq,
+                 )
+                 if not is_nextn
+                 else [nextn_layer_id]
+             )
+
+             for moe_layer in moe_layers:
+                 for suffix in suffix_list:
+                     shared_expert_weight_name = (
+                         f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
+                     )
+                     # online fp8 quantization does not load weight_scale
+                     if shared_expert_weight_name not in weights_dict:
+                         continue
+                     weights_list.append(
+                         (
+                             f"model.layers.{moe_layer}."
+                             f"mlp.experts."
+                             f"{self.config.n_routed_experts + 0}"
+                             f".{suffix}",
+                             weights_dict[shared_expert_weight_name],
+                         )
+                     )
+                     names_to_remove += [shared_expert_weight_name]
+             weights = [w for w in weights_list if w[0] not in names_to_remove]
+
+         # Params for weights, fp8 weight scales, fp8 activation scales
+         # (param_name, weight_name, expert_id, shard_id)
+         expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+             ckpt_gate_proj_name="gate_proj",
+             ckpt_down_proj_name="down_proj",
+             ckpt_up_proj_name="up_proj",
+             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
+         )
+
+         # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+         fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+             self.config.q_lora_rank is not None
+         )
+         cached_a_proj = {} if fuse_qkv_a_proj else None
+
+         if is_nextn:
+             nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+             nextn_spec_weight_names = [
+                 "shared_head.norm",
+                 "eh_proj",
+                 "enorm",
+                 "hnorm",
+             ]
+
+         params_dict = dict(self.named_parameters())
+         weight_names = []
+         for name, loaded_weight in weights:
+             weight_names.append(name)
+
+             if not is_nextn:
+                 if hasattr(self.config, "num_nextn_predict_layers"):
+                     num_nextn_layers = self.config.num_nextn_predict_layers
+                     if num_nextn_layers > 0 and name.startswith("model.layers"):
+                         name_list = name.split(".")
+                         if (
+                             len(name_list) >= 3
+                             and int(name_list[2]) >= self.config.num_hidden_layers
+                         ):
+                             continue
+             else:
+                 if not name.startswith(nextn_layer_prefix):
+                     continue
+
+                 # Use shared head and embed weights from target model
+                 if "shared_head.head" in name or "embed_tokens" in name:
+                     continue
+
+                 is_decoder = True
+                 # For nextn specific weights
+                 for weight_name in nextn_spec_weight_names:
+                     if weight_name in name:
+                         name = name.replace(nextn_layer_prefix, "model")
+                         is_decoder = False
+                         break
+                 # For decoder layer weights
+                 if is_decoder:
+                     name = name.replace(nextn_layer_prefix, "model.decoder")
+
+             if "rotary_emb.inv_freq" in name:
+                 continue
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 # Skip non-stacked layers and experts (experts handled below).
+                 if weight_name not in name:
+                     continue
+                 # We have mlp.experts[0].gate_proj in the checkpoint.
+                 # Since we handle the experts below in expert_params_mapping,
+                 # we need to skip here BEFORE we update the name, otherwise
+                 # name will be updated to mlp.experts[0].gate_up_proj, which
+                 # will then be updated below in expert_params_mapping
+                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                 if ("mlp.experts." in name) and name not in params_dict:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 for mapping in expert_params_mapping:
+                     param_name, weight_name, expert_id, shard_id = mapping
+                     if weight_name not in name:
+                         continue
+                     name = name.replace(weight_name, param_name)
+                     param = params_dict[name]
+                     weight_loader = param.weight_loader
+                     weight_loader(
+                         param,
+                         loaded_weight,
+                         name,
+                         shard_id=shard_id,
+                         expert_id=expert_id,
+                     )
+                     break
+                 else:
+                     # Skip loading extra bias for GPTQ models.
+                     if name.endswith(".bias") and name not in params_dict:
+                         continue
+                     if fuse_qkv_a_proj and (
+                         "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                     ):
+                         cached_a_proj[name] = loaded_weight
+                         q_a_proj_name = (
+                             name
+                             if "q_a_proj" in name
+                             else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                         )
+                         kv_a_proj_name = (
+                             name
+                             if "kv_a_proj_with_mqa" in name
+                             else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                         )
+
+                         # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                         if (
+                             q_a_proj_name in cached_a_proj
+                             and kv_a_proj_name in cached_a_proj
+                         ):
+                             q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                             kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                             fused_weight = torch.cat(
+                                 [q_a_proj_weight, kv_a_proj_weight], dim=0
+                             )
+                             param_name = (
+                                 name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                                 if "q_a_proj" in name
+                                 else name.replace(
+                                     "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+                                 )
+                             )
+                             param = params_dict[param_name]
+
+                             weight_loader = getattr(
+                                 param, "weight_loader", default_weight_loader
+                             )
+                             weight_loader(param, fused_weight)
+                             cached_a_proj.pop(q_a_proj_name)
+                             cached_a_proj.pop(kv_a_proj_name)
+                         else:
+                             if (
+                                 "k_scale" in name or "v_scale" in name
+                             ) and name not in params_dict:
+                                 # modelopt attn kv scale is named differently
+                                 if any(scale in name for scale in ["k_scale", "v_scale"]):
+                                     name = name.replace("_proj", "attn_mqa")
+                                 else:
+                                     logger.warning(
+                                         f"Unknown scale found in checkpoint: {name}"
+                                     )
+                             param = params_dict[name]
+                             weight_loader = getattr(
+                                 param, "weight_loader", default_weight_loader
+                             )
+                             weight_loader(param, loaded_weight)
+
+
+ EntryClass = [Glm4MoeForCausalLM]
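Because the new module ends by registering Glm4MoeForCausalLM in EntryClass, sglang's model loader can resolve GLM-4.5-family checkpoints in this release. A minimal offline sketch of exercising it, assuming a GPU install of this sglang version and that the offline Engine API behaves as in the sglang documentation; the checkpoint id and sampling settings are illustrative placeholders, not values taken from this diff.

import sglang as sgl

if __name__ == "__main__":
    # Placeholder checkpoint id; point it at a real GLM-4.5 / GLM-4.5-Air path and adjust tp_size to your GPUs.
    llm = sgl.Engine(model_path="zai-org/GLM-4.5-Air", tp_size=4)
    prompts = ["Briefly explain what a mixture-of-experts layer does."]
    outputs = llm.generate(prompts, {"temperature": 0.6, "max_new_tokens": 128})
    for item in outputs:
        print(item["text"])
    llm.shutdown()

The HTTP server path should be equivalent: python -m sglang.launch_server --model-path <checkpoint> with the usual tensor-parallel flags.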