sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,400 @@
+import logging
+from functools import lru_cache
+from typing import Iterable, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
+
+from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    parallel_state,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    get_local_attention_dp_size,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.glm4_moe import Glm4MoeModel
+from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
+from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
+
+_is_cuda = is_cuda()
+
+logger = logging.getLogger(__name__)
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
+    def __init__(
+        self,
+        config: Glm4vMoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+
+        config.moe_layer_freq = 1
+        self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.dp_size = get_local_attention_dp_size()
+        self.quant_config = quant_config
+        self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
+        self.num_fused_shared_experts = (
+            0
+            if global_server_args_dict["disable_shared_experts_fusion"]
+            else config.n_shared_experts
+        )
+
+        self.model = Glm4MoeModel(
+            config,
+            quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
+        self.visual = Glm4vVisionModel(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
+            quant_config=quant_config,
+            prefix=add_prefix("visual", prefix),
+        )
+
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+
+    def determine_num_fused_shared_experts(
+        self, architecture: str = "Glm4MoeForCausalLM"
+    ):
+        self.num_fused_shared_experts = 0
+        if global_server_args_dict["disable_shared_experts_fusion"]:
+            return
+
+        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+        disable_reason = None
+        if (
+            not _is_cuda
+            or torch.cuda.get_device_capability("cuda") < (8, 0)
+            or self.config.architectures[0] != architecture
+            or self.config.n_shared_experts != 1
+        ):
+            disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
+        elif get_moe_expert_parallel_world_size() > 1:
+            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
+
+        if disable_reason is not None:
+            global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
+            log_info_on_rank0(
+                logger,
+                f"{disable_reason} Shared experts fusion optimization is disabled.",
+            )
+            return
+
+        self.num_fused_shared_experts = self.config.n_shared_experts
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+
+        if is_nextn:
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
+                # compatible with old design
+                nextn_layer_id = (
+                    0
+                    if self.config.num_hidden_layers == 1
+                    else self.config.num_hidden_layers
+                )
+            else:
+                raise ValueError("num_nextn_predict_layers is not in the config")
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        if self.num_fused_shared_experts > 0:
+            assert self.num_fused_shared_experts == 1
+            weights_list = list(weights)
+            weights_dict = dict(weights_list)
+            if self.quant_config is not None:
+                if self.quant_config.get_name() == "w8a8_int8":
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                    ]
+                elif (
+                    self.quant_config.get_name() == "fp8"
+                    or self.quant_config.get_name() == "blockwise_int8"
+                    or self.quant_config.get_name() == "compressed_tensors"
+                ):
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                    ]
+                elif self.quant_config.get_name() == "awq":
+                    suffix_list = [
+                        "down_proj.qweight",
+                        "down_proj.qzeros",
+                        "down_proj.scales",
+                        "gate_proj.qweight",
+                        "gate_proj.qzeros",
+                        "gate_proj.scales",
+                        "up_proj.qweight",
+                        "up_proj.qzeros",
+                        "up_proj.scales",
+                    ]
+                elif self.quant_config.get_name() == "modelopt_fp4":
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "down_proj.weight_scale_2",
+                        "down_proj.input_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "gate_proj.weight_scale_2",
+                        "gate_proj.input_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                        "up_proj.weight_scale_2",
+                        "up_proj.input_scale",
+                    ]
+                else:
+                    raise ValueError(
+                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
+                    )
+            else:
+                suffix_list = [
+                    "down_proj.weight",
+                    "gate_proj.weight",
+                    "up_proj.weight",
+                ]
+            names_to_remove = []
+
+            moe_layers = (
+                range(
+                    self.config.first_k_dense_replace,
+                    self.config.num_hidden_layers,
+                    self.config.moe_layer_freq,
+                )
+                if not is_nextn
+                else [nextn_layer_id]
+            )
+
+            for moe_layer in moe_layers:
+                for suffix in suffix_list:
+                    shared_expert_weight_name = (
+                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
+                    )
+                    # online fp8 quantization does not load weight_scale
+                    if shared_expert_weight_name not in weights_dict:
+                        continue
+                    weights_list.append(
+                        (
+                            f"model.layers.{moe_layer}."
+                            f"mlp.experts."
+                            f"{self.config.n_routed_experts + 0}"
+                            f".{suffix}",
+                            weights_dict[shared_expert_weight_name],
+                        )
+                    )
+                    names_to_remove += [shared_expert_weight_name]
+            weights = [w for w in weights_list if w[0] not in names_to_remove]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
+        )
+
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+            self.config.q_lora_rank is not None
+        )
+        cached_a_proj = {} if fuse_qkv_a_proj else None
+
+        if is_nextn:
+            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+            nextn_spec_weight_names = [
+                "shared_head.norm",
+                "eh_proj",
+                "enorm",
+                "hnorm",
+            ]
+
+        params_dict = dict(self.named_parameters())
+        weight_names = []
+        for name, loaded_weight in weights:
+            weight_names.append(name)
+
+            if not is_nextn:
+                if hasattr(self.config, "num_nextn_predict_layers"):
+                    num_nextn_layers = self.config.num_nextn_predict_layers
+                    if num_nextn_layers > 0 and name.startswith("model.layers"):
+                        name_list = name.split(".")
+                        if (
+                            len(name_list) >= 3
+                            and int(name_list[2]) >= self.config.num_hidden_layers
+                        ):
+                            continue
+            else:
+                if not name.startswith(nextn_layer_prefix):
+                    continue
+
+                # Use shared head and embed weights from target model
+                if "shared_head.head" in name or "embed_tokens" in name:
+                    continue
+
+                is_decoder = True
+                # For nextn specific weights
+                for weight_name in nextn_spec_weight_names:
+                    if weight_name in name:
+                        name = name.replace(nextn_layer_prefix, "model")
+                        is_decoder = False
+                        break
+                # For decoder layer weights
+                if is_decoder:
+                    name = name.replace(nextn_layer_prefix, "model.decoder")
+
+            if "language_model." in name:
+                name = name.replace("language_model.", "")
+            if "model.visual." in name:
+                name = name.replace("model.visual.", "visual.")
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if ("mlp.experts." in name) and name not in params_dict:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    if "visual" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if fuse_qkv_a_proj and (
+                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                    ):
+                        cached_a_proj[name] = loaded_weight
+                        q_a_proj_name = (
+                            name
+                            if "q_a_proj" in name
+                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                        )
+                        kv_a_proj_name = (
+                            name
+                            if "kv_a_proj_with_mqa" in name
+                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                        )
+
+                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                        if (
+                            q_a_proj_name in cached_a_proj
+                            and kv_a_proj_name in cached_a_proj
+                        ):
+                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                            fused_weight = torch.cat(
+                                [q_a_proj_weight, kv_a_proj_weight], dim=0
+                            )
+                            param_name = (
+                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                                if "q_a_proj" in name
+                                else name.replace(
+                                    "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+                                )
+                            )
+                            param = params_dict[param_name]
+
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            weight_loader(param, fused_weight)
+                            cached_a_proj.pop(q_a_proj_name)
+                            cached_a_proj.pop(kv_a_proj_name)
+                    else:
+                        if (
+                            "k_scale" in name or "v_scale" in name
+                        ) and name not in params_dict:
+                            # modelopt attn kv scale is named differently
+                            if any(scale in name for scale in ["k_scale", "v_scale"]):
+                                name = name.replace("_proj", "attn_mqa")
+                            else:
+                                logger.warning(
+                                    f"Unknown scale found in checkpoint: {name}"
+                                )
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+
+
+EntryClass = [Glm4vMoeForConditionalGeneration]