sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/check_env.py +3 -3
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/kimi_vl.py +38 -0
  5. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  6. sglang/srt/configs/model_config.py +15 -0
  7. sglang/srt/conversation.py +122 -1
  8. sglang/srt/entrypoints/engine.py +44 -22
  9. sglang/srt/function_call_parser.py +97 -0
  10. sglang/srt/hf_transformers_utils.py +2 -0
  11. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  12. sglang/srt/layers/attention/flashinfer_backend.py +107 -82
  13. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
  14. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  15. sglang/srt/layers/dp_attention.py +5 -2
  16. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -6
  22. sglang/srt/layers/quantization/__init__.py +2 -2
  23. sglang/srt/layers/quantization/deep_gemm.py +1 -1
  24. sglang/srt/layers/utils.py +35 -0
  25. sglang/srt/lora/layers.py +35 -9
  26. sglang/srt/lora/lora_manager.py +84 -35
  27. sglang/srt/managers/data_parallel_controller.py +52 -34
  28. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  29. sglang/srt/managers/schedule_batch.py +25 -15
  30. sglang/srt/managers/scheduler.py +263 -59
  31. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  32. sglang/srt/managers/tp_worker.py +51 -16
  33. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  34. sglang/srt/mem_cache/memory_pool.py +70 -36
  35. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  36. sglang/srt/model_executor/forward_batch_info.py +31 -1
  37. sglang/srt/model_executor/model_runner.py +115 -57
  38. sglang/srt/models/deepseek_nextn.py +1 -257
  39. sglang/srt/models/deepseek_v2.py +78 -18
  40. sglang/srt/models/kimi_vl.py +308 -0
  41. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  42. sglang/srt/models/llama.py +92 -30
  43. sglang/srt/models/llama4.py +2 -1
  44. sglang/srt/models/llama_eagle.py +4 -1
  45. sglang/srt/models/llama_eagle3.py +4 -1
  46. sglang/srt/models/qwen2_moe.py +8 -3
  47. sglang/srt/models/qwen2_vl.py +0 -12
  48. sglang/srt/models/qwen3_moe.py +8 -3
  49. sglang/srt/openai_api/adapter.py +34 -22
  50. sglang/srt/openai_api/protocol.py +11 -1
  51. sglang/srt/server_args.py +67 -22
  52. sglang/srt/speculative/eagle_worker.py +3 -2
  53. sglang/srt/utils.py +88 -9
  54. sglang/test/runners.py +4 -0
  55. sglang/test/test_utils.py +29 -0
  56. sglang/version.py +1 -1
  57. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
  58. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +61 -51
  59. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
  60. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
  61. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py

@@ -13,8 +13,10 @@
  # ==============================================================================
  """ModelRunner runs the forward passes of the models."""

+ import collections
  import datetime
  import gc
+ import inspect
  import json
  import logging
  import os
@@ -59,7 +61,7 @@ from sglang.srt.mem_cache.memory_pool import (
  )
  from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
  from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
  from sglang.srt.model_loader import get_model
  from sglang.srt.model_loader.loader import (
  DefaultModelLoader,
@@ -110,6 +112,8 @@ class ModelRunner:
  gpu_id: int,
  tp_rank: int,
  tp_size: int,
+ pp_rank: int,
+ pp_size: int,
  nccl_port: int,
  server_args: ServerArgs,
  is_draft_worker: bool = False,
@@ -123,6 +127,8 @@ class ModelRunner:
  self.gpu_id = gpu_id
  self.tp_rank = tp_rank
  self.tp_size = tp_size
+ self.pp_rank = pp_rank
+ self.pp_size = pp_size
  self.dist_port = nccl_port
  self.server_args = server_args
  self.is_draft_worker = is_draft_worker
@@ -148,24 +154,24 @@ class ModelRunner:
  global_server_args_dict.update(
  {
  "attention_backend": server_args.attention_backend,
- "sampling_backend": server_args.sampling_backend,
- "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
- "torchao_config": server_args.torchao_config,
+ "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
+ "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
+ "deepep_mode": server_args.deepep_mode,
+ "device": server_args.device,
+ "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
+ "disable_radix_cache": server_args.disable_radix_cache,
  "enable_nan_detection": server_args.enable_nan_detection,
  "enable_dp_attention": server_args.enable_dp_attention,
  "enable_ep_moe": server_args.enable_ep_moe,
  "enable_deepep_moe": server_args.enable_deepep_moe,
- "deepep_mode": server_args.deepep_mode,
- "device": server_args.device,
- "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
- "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
- "disable_radix_cache": server_args.disable_radix_cache,
  "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
  "moe_dense_tp_size": server_args.moe_dense_tp_size,
- "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
- "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
  "n_share_experts_fusion": server_args.n_share_experts_fusion,
- "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
+ "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
+ "torchao_config": server_args.torchao_config,
+ "sampling_backend": server_args.sampling_backend,
+ "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
+ "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
  "use_mla_backend": self.use_mla_backend,
  }
  )
@@ -183,6 +189,11 @@ class ModelRunner:
  # If it is a draft model, tp_group can be different
  self.initialize(min_per_gpu_memory)

+ # temporary cached values
+ self.support_pp = (
+ "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
+ )
+
  def initialize(self, min_per_gpu_memory: float):
  server_args = self.server_args
  self.memory_saver_adapter = TorchMemorySaverAdapter.create(
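The `support_pp` flag added above is plain `inspect` reflection: the runner checks whether the loaded model's `forward` declares a `pp_proxy_tensors` parameter before ever passing it. A minimal, self-contained sketch of the same check, using a hypothetical `DummyModel` purely for illustration:

    import inspect

    class DummyModel:
        # Hypothetical model whose forward accepts the optional PP argument.
        def forward(self, input_ids, positions, forward_batch, pp_proxy_tensors=None):
            return input_ids

    model = DummyModel()
    support_pp = "pp_proxy_tensors" in inspect.signature(model.forward).parameters
    print(support_pp)  # True; the extra kwarg is only forwarded when the model declares it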
@@ -193,6 +204,12 @@ class ModelRunner:
  self.sampler = Sampler()
  self.load_model()

+ self.start_layer = getattr(self.model, "start_layer", 0)
+ self.end_layer = getattr(
+ self.model, "end_layer", self.model_config.num_hidden_layers
+ )
+ self.num_effective_layers = self.end_layer - self.start_layer
+
  # Apply torchao quantization
  torchao_applied = getattr(self.model, "torchao_applied", False)
  # In layered loading, torchao may have been applied
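The `getattr` defaults above keep models that are not split across pipeline stages on the full layer range, while stage-local models can expose their own slice. A small sketch under assumed values (a hypothetical 32-layer model and a hypothetical second pipeline stage covering layers 16–31):

    # Models without start_layer/end_layer attributes fall back to the full range.
    class FullModel:
        pass

    class StageModel:
        start_layer, end_layer = 16, 32  # hypothetical second pipeline stage

    num_hidden_layers = 32
    for m in (FullModel(), StageModel()):
        start = getattr(m, "start_layer", 0)
        end = getattr(m, "end_layer", num_hidden_layers)
        print(start, end, end - start)  # prints "0 32 32", then "16 32 16"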
@@ -359,18 +376,22 @@ class ModelRunner:
  # Only initialize the distributed environment on the target model worker.
  init_distributed_environment(
  backend=backend,
- world_size=self.tp_size,
- rank=self.tp_rank,
+ world_size=self.tp_size * self.pp_size,
+ rank=self.tp_size * self.pp_rank + self.tp_rank,
  local_rank=self.gpu_id,
  distributed_init_method=dist_init_method,
  timeout=self.server_args.dist_timeout,
  )
- initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+ initialize_model_parallel(
+ tensor_model_parallel_size=self.tp_size,
+ pipeline_model_parallel_size=self.pp_size,
+ )
  initialize_dp_attention(
  enable_dp_attention=self.server_args.enable_dp_attention,
  tp_rank=self.tp_rank,
  tp_size=self.tp_size,
  dp_size=self.server_args.dp_size,
+ pp_size=self.server_args.pp_size,
  )

  min_per_gpu_memory = get_available_gpu_memory(
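The new rank formula `tp_size * pp_rank + tp_rank` lays out tensor-parallel ranks contiguously within each pipeline stage and stacks the stages. A sketch that just enumerates the mapping, with assumed sizes tp_size=4 and pp_size=2:

    tp_size, pp_size = 4, 2  # illustrative sizes, not defaults
    for pp_rank in range(pp_size):
        for tp_rank in range(tp_size):
            global_rank = tp_size * pp_rank + tp_rank
            print(f"pp={pp_rank} tp={tp_rank} -> global rank {global_rank}")
    # world_size passed to init_distributed_environment is tp_size * pp_size = 8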
@@ -692,16 +713,23 @@ class ModelRunner:
  self.device, self.gpu_id, distributed=self.tp_size > 1
  )
  if self.use_mla_backend:
+ num_layers = (
+ self.model_config.num_hidden_layers
+ if not self.is_draft_worker
+ else self.model_config.hf_config.num_nextn_predict_layers
+ )
+ # FIXME: pipeline parallelism is not compatible with mla backend
+ assert self.pp_size == 1
  cell_size = (
  (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
- * self.model_config.num_hidden_layers
+ * num_layers
  * torch._utils._element_size(self.kv_cache_dtype)
  )
  else:
  cell_size = (
  self.model_config.get_num_kv_heads(get_attention_tp_size())
  * self.model_config.head_dim
- * self.model_config.num_hidden_layers
+ * self.num_effective_layers
  * 2
  * torch._utils._element_size(self.kv_cache_dtype)
  )
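The two branches above compute the per-token KV-cache cost: the MLA path caches one compressed `(kv_lora_rank + qk_rope_head_dim)` vector per layer, while the MHA path caches K and V per head for only the pipeline-local layers. A back-of-the-envelope sketch with illustrative dimensions (not taken from any specific config):

    elem = 2  # assumed bytes per element, e.g. bf16

    # MLA branch: compressed latent per layer.
    kv_lora_rank, qk_rope_head_dim, num_layers = 512, 64, 61
    mla_cell = (kv_lora_rank + qk_rope_head_dim) * num_layers * elem

    # MHA branch: K and V per KV head, only for the effective (local) layers.
    kv_heads, head_dim, num_effective_layers = 32, 128, 32
    mha_cell = kv_heads * head_dim * num_effective_layers * 2 * elem

    print(mla_cell, mha_cell)  # bytes of KV cache per cached token in each scheme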
@@ -809,9 +837,15 @@ class ModelRunner:
  dtype=self.kv_cache_dtype,
  kv_lora_rank=self.model_config.kv_lora_rank,
  qk_rope_head_dim=self.model_config.qk_rope_head_dim,
- layer_num=self.model_config.num_hidden_layers,
+ layer_num=(
+ self.model_config.num_hidden_layers
+ if not self.is_draft_worker
+ else self.model_config.hf_config.num_nextn_predict_layers
+ ), # PP is not compatible with mla backend
  device=self.device,
  enable_memory_saver=self.server_args.enable_memory_saver,
+ start_layer=self.start_layer,
+ end_layer=self.end_layer,
  )
  elif self.server_args.enable_double_sparsity:
  self.token_to_kv_pool = DoubleSparseTokenToKVPool(
@@ -820,10 +854,12 @@ class ModelRunner:
  dtype=self.kv_cache_dtype,
  head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
  head_dim=self.model_config.head_dim,
- layer_num=self.model_config.num_hidden_layers,
+ layer_num=self.num_effective_layers,
  device=self.device,
  heavy_channel_num=self.server_args.ds_heavy_channel_num,
  enable_memory_saver=self.server_args.enable_memory_saver,
+ start_layer=self.start_layer,
+ end_layer=self.end_layer,
  )
  else:
  self.token_to_kv_pool = MHATokenToKVPool(
@@ -832,9 +868,11 @@ class ModelRunner:
  dtype=self.kv_cache_dtype,
  head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
  head_dim=self.model_config.head_dim,
- layer_num=self.model_config.num_hidden_layers,
+ layer_num=self.num_effective_layers,
  device=self.device,
  enable_memory_saver=self.server_args.enable_memory_saver,
+ start_layer=self.start_layer,
+ end_layer=self.end_layer,
  )

  if self.token_to_kv_pool_allocator is None:
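The KV pools above now receive `layer_num=self.num_effective_layers` plus `start_layer`/`end_layer`, i.e. they allocate buffers only for the pipeline-local slice while callers keep using global layer ids. A sketch of that idea (not the actual MHATokenToKVPool code; class and method names here are hypothetical):

    import torch

    class SlicedKVPool:
        def __init__(self, size, head_num, head_dim, start_layer, end_layer, dtype=torch.float16):
            self.start_layer = start_layer
            # One buffer per local layer only.
            self.k_buffer = [
                torch.zeros(size, head_num, head_dim, dtype=dtype)
                for _ in range(end_layer - start_layer)
            ]

        def get_key_buffer(self, layer_id):
            # Global layer ids are offset by start_layer before indexing.
            return self.k_buffer[layer_id - self.start_layer]

    pool = SlicedKVPool(size=8, head_num=2, head_dim=4, start_layer=16, end_layer=20)
    print(pool.get_key_buffer(18).shape)  # torch.Size([8, 2, 4])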
@@ -918,8 +956,10 @@ class ModelRunner:

  self.attn_backend = FlashMLABackend(self)
  elif self.server_args.attention_backend == "fa3":
- assert torch.cuda.get_device_capability()[0] >= 9, (
- "FlashAttention v3 Backend requires SM>=90. "
+ assert (
+ torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
+ ) or torch.cuda.get_device_capability()[0] == 9, (
+ "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
  "Please use `--attention-backend flashinfer`."
  )
  from sglang.srt.layers.attention.flashattention_backend import (
945
985
  with open(self.server_args.ds_channel_config_path, "r") as f:
946
986
  channel_config = json.load(f)
947
987
 
948
- for i in range(self.model_config.num_hidden_layers):
988
+ for i in range(self.start_layer, self.end_layer):
949
989
  key = "model.layers." + str(i) + ".self_attn" + selected_channel
950
990
  self.sorted_channels.append(
951
991
  torch.tensor(channel_config[key])[
@@ -985,64 +1025,82 @@ class ModelRunner:
985
1025
  device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
986
1026
  tensor_parallel(self.model, device_mesh)
987
1027
 
988
- def forward_decode(self, forward_batch: ForwardBatch):
1028
+ def forward_decode(
1029
+ self, forward_batch: ForwardBatch, pp_proxy_tensors=None
1030
+ ) -> LogitsProcessorOutput:
989
1031
  self.attn_backend.init_forward_metadata(forward_batch)
1032
+ # FIXME: add pp_proxy_tensors arg to all models
1033
+ kwargs = {}
1034
+ if self.support_pp:
1035
+ kwargs["pp_proxy_tensors"] = pp_proxy_tensors
990
1036
  return self.model.forward(
991
- forward_batch.input_ids, forward_batch.positions, forward_batch
1037
+ forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
992
1038
  )
993
1039
 
994
1040
  def forward_extend(
995
- self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
996
- ):
1041
+ self,
1042
+ forward_batch: ForwardBatch,
1043
+ skip_attn_backend_init: bool = False,
1044
+ pp_proxy_tensors=None,
1045
+ ) -> LogitsProcessorOutput:
997
1046
  if not skip_attn_backend_init:
998
1047
  self.attn_backend.init_forward_metadata(forward_batch)
999
1048
 
1000
- if self.is_generation:
1001
- if forward_batch.input_embeds is None:
1002
- return self.model.forward(
1003
- forward_batch.input_ids, forward_batch.positions, forward_batch
1004
- )
1005
- else:
1006
- return self.model.forward(
1007
- forward_batch.input_ids,
1008
- forward_batch.positions,
1009
- forward_batch,
1010
- input_embeds=forward_batch.input_embeds.bfloat16(),
1011
- )
1012
- else:
1013
- # Only embedding models have get_embedding parameter
1014
- return self.model.forward(
1015
- forward_batch.input_ids,
1016
- forward_batch.positions,
1017
- forward_batch,
1018
- get_embedding=True,
1019
- )
1049
+ kwargs = {}
1050
+ if self.support_pp:
1051
+ kwargs["pp_proxy_tensors"] = pp_proxy_tensors
1052
+ if forward_batch.input_embeds is not None:
1053
+ kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
1054
+ if not self.is_generation:
1055
+ kwargs["get_embedding"] = True
1056
+ return self.model.forward(
1057
+ forward_batch.input_ids,
1058
+ forward_batch.positions,
1059
+ forward_batch,
1060
+ **kwargs,
1061
+ )
1020
1062
 
1021
- def forward_idle(self, forward_batch: ForwardBatch):
1063
+ def forward_idle(
1064
+ self, forward_batch: ForwardBatch, pp_proxy_tensors=None
1065
+ ) -> LogitsProcessorOutput:
1066
+ kwargs = {}
1067
+ if self.support_pp:
1068
+ kwargs["pp_proxy_tensors"] = pp_proxy_tensors
1022
1069
  return self.model.forward(
1023
- forward_batch.input_ids, forward_batch.positions, forward_batch
1070
+ forward_batch.input_ids,
1071
+ forward_batch.positions,
1072
+ forward_batch,
1073
+ **kwargs,
1024
1074
  )
1025
1075
 
1026
1076
  def forward(
1027
- self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
1028
- ) -> LogitsProcessorOutput:
1029
- if (
1077
+ self,
1078
+ forward_batch: ForwardBatch,
1079
+ skip_attn_backend_init: bool = False,
1080
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
1081
+ ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
1082
+ can_run_cuda_graph = bool(
1030
1083
  forward_batch.forward_mode.is_cuda_graph()
1031
1084
  and self.cuda_graph_runner
1032
1085
  and self.cuda_graph_runner.can_run(forward_batch)
1033
- ):
1086
+ )
1087
+ if can_run_cuda_graph:
1034
1088
  return self.cuda_graph_runner.replay(
1035
- forward_batch, skip_attn_backend_init=skip_attn_backend_init
1089
+ forward_batch,
1090
+ skip_attn_backend_init=skip_attn_backend_init,
1091
+ pp_proxy_tensors=pp_proxy_tensors,
1036
1092
  )
1037
1093
 
1038
1094
  if forward_batch.forward_mode.is_decode():
1039
- return self.forward_decode(forward_batch)
1095
+ return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
1040
1096
  elif forward_batch.forward_mode.is_extend():
1041
1097
  return self.forward_extend(
1042
- forward_batch, skip_attn_backend_init=skip_attn_backend_init
1098
+ forward_batch,
1099
+ skip_attn_backend_init=skip_attn_backend_init,
1100
+ pp_proxy_tensors=pp_proxy_tensors,
1043
1101
  )
1044
1102
  elif forward_batch.forward_mode.is_idle():
1045
- return self.forward_idle(forward_batch)
1103
+ return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
1046
1104
  else:
1047
1105
  raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
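The rewritten `forward_extend` replaces the old if/else ladder with a single call site: optional arguments are collected into a `kwargs` dict and only forwarded when they apply. A standalone sketch of that pattern (the function and the stand-in model here are hypothetical, not sglang APIs):

    def run_extend(model, batch, support_pp, is_generation, pp_proxy_tensors=None):
        # Collect only the optional arguments that actually apply to this call.
        kwargs = {}
        if support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        if batch.get("input_embeds") is not None:
            kwargs["input_embeds"] = batch["input_embeds"]
        if not is_generation:
            kwargs["get_embedding"] = True
        return model(batch["input_ids"], batch["positions"], batch, **kwargs)

    out = run_extend(
        model=lambda ids, pos, fb, **kw: (ids, kw),  # stand-in for self.model.forward
        batch={"input_ids": [1, 2], "positions": [0, 1], "input_embeds": None},
        support_pp=False,
        is_generation=True,
    )
    print(out)  # ([1, 2], {})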
sglang/srt/models/deepseek_nextn.py

@@ -177,263 +177,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
- if hasattr(self.config, "num_nextn_predict_layers"):
- num_nextn_layers = self.config.num_nextn_predict_layers
- assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
- assert num_nextn_layers == self.config.num_hidden_layers
- else:
- raise ValueError("num_nextn_predict_layers is not in the config")
-
- stacked_params_mapping = [
- # (param_name, shard_name, shard_id)
- ("gate_up_proj", "gate_proj", 0),
- ("gate_up_proj", "up_proj", 1),
- ]
- if self.n_share_experts_fusion > 0:
- logger.info(
- f"Cloning {self.n_share_experts_fusion} "
- "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
- )
- weights_list = list(weights)
- weights_dict = dict(weights_list)
- if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
- suffix_list = [
- "down_proj.weight",
- "down_proj.weight_scale",
- "gate_proj.weight",
- "gate_proj.weight_scale",
- "up_proj.weight",
- "up_proj.weight_scale",
- ]
- else:
- suffix_list = [
- "down_proj.weight",
- "down_proj.weight_scale_inv",
- "gate_proj.weight",
- "gate_proj.weight_scale_inv",
- "up_proj.weight",
- "up_proj.weight_scale_inv",
- ]
- names_to_remove = []
- for suffix in suffix_list:
- shared_expert_weight_name = (
- f"model.layers.0.mlp.shared_experts.{suffix}"
- )
- for num_repeat in range(self.n_share_experts_fusion):
- weights_list.append(
- (
- f"model.layers.0."
- f"mlp.experts."
- f"{self.config.n_routed_experts + num_repeat}"
- f".{suffix}",
- weights_dict[shared_expert_weight_name],
- )
- )
- names_to_remove += [shared_expert_weight_name]
- weights = [w for w in weights_list if w[0] not in names_to_remove]
-
- # Params for weights, fp8 weight scales, fp8 activation scales
- # (param_name, weight_name, expert_id, shard_id)
- MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
- expert_params_mapping = MoEImpl.make_expert_params_mapping(
- ckpt_gate_proj_name="gate_proj",
- ckpt_down_proj_name="down_proj",
- ckpt_up_proj_name="up_proj",
- num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
- )
-
- # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
- fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
- self.config.q_lora_rank is not None
- )
- cached_a_proj = {} if fuse_qkv_a_proj else None
-
- nextn_layer_prefix = "model.layers.0"
- nextn_spec_weight_names = [
- "shared_head.norm",
- "eh_proj",
- "enorm",
- "hnorm",
- ]
-
- params_dict = dict(self.named_parameters())
- for name, loaded_weight in weights:
- if not name.startswith(nextn_layer_prefix):
- continue
-
- # Use shared head and embed weights from target model
- if "shared_head.head" in name or "embed_tokens" in name:
- continue
-
- is_decoder = True
- # For nextn specific weights
- for weight_name in nextn_spec_weight_names:
- if weight_name in name:
- name = name.replace(nextn_layer_prefix, "model")
- is_decoder = False
- break
- # For decoder layer weights
- if is_decoder:
- name = name.replace(nextn_layer_prefix, "model.decoder")
-
- if "rotary_emb.inv_freq" in name:
- continue
- for param_name, weight_name, shard_id in stacked_params_mapping:
- # Skip non-stacked layers and experts (experts handled below).
- if weight_name not in name:
- continue
- # We have mlp.experts[0].gate_proj in the checkpoint.
- # Since we handle the experts below in expert_params_mapping,
- # we need to skip here BEFORE we update the name, otherwise
- # name will be updated to mlp.experts[0].gate_up_proj, which
- # will then be updated below in expert_params_mapping
- # for mlp.experts[0].gate_gate_up_proj, which breaks load.
- if ("mlp.experts." in name) and name not in params_dict:
- continue
- name = name.replace(weight_name, param_name)
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
- break
- else:
- for mapping in expert_params_mapping:
- param_name, weight_name, expert_id, shard_id = mapping
- if weight_name not in name:
- continue
- name = name.replace(weight_name, param_name)
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(
- param,
- loaded_weight,
- name,
- shard_id=shard_id,
- expert_id=expert_id,
- )
- break
- else:
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
-
- # Handle fused_qkv_a_proj
- if fuse_qkv_a_proj and (
- "q_a_proj" in name or "kv_a_proj_with_mqa" in name
- ):
- cached_a_proj[name] = loaded_weight
- q_a_proj_name = (
- name
- if "q_a_proj" in name
- else name.replace("kv_a_proj_with_mqa", "q_a_proj")
- )
- kv_a_proj_name = (
- name
- if "kv_a_proj_with_mqa" in name
- else name.replace("q_a_proj", "kv_a_proj_with_mqa")
- )
-
- # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
- if (
- q_a_proj_name in cached_a_proj
- and kv_a_proj_name in cached_a_proj
- ):
-
- q_a_proj_weight = cached_a_proj[q_a_proj_name]
- kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
- fused_weight = torch.cat(
- [q_a_proj_weight, kv_a_proj_weight], dim=0
- )
-
- param_name = name.replace(
- "q_a_proj", "fused_qkv_a_proj_with_mqa"
- )
- param = params_dict[param_name]
-
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
- weight_loader(param, fused_weight)
- cached_a_proj.pop(q_a_proj_name)
- cached_a_proj.pop(kv_a_proj_name)
- else:
- param = params_dict[name]
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
- weight_loader(param, loaded_weight)
-
- self_attn = self.model.decoder.self_attn
- if hasattr(self_attn.kv_b_proj, "qweight"):
- # AWQ compatible
- if _is_cuda:
- w = awq_dequantize(
- self_attn.kv_b_proj.qweight,
- self_attn.kv_b_proj.scales,
- self_attn.kv_b_proj.qzeros,
- ).T
- else:
- w = awq_dequantize(
- self_attn.kv_b_proj.qweight,
- self_attn.kv_b_proj.scales,
- self_attn.kv_b_proj.qzeros,
- 0,
- 0,
- 0,
- ).T
- else:
- w = self_attn.kv_b_proj.weight
- # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
- # This may affect the accuracy of fp8 model.
- if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
- torch.float8_e4m3fn,
- torch.float8_e4m3fnuz,
- ):
- weight_block_size = self.quant_config.weight_block_size
- if weight_block_size is not None:
- assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
- if _is_hip:
- weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
- weight=w,
- weight_scale=self_attn.kv_b_proj.weight_scale_inv,
- input_scale=None,
- )
- else:
- weight = w
- weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
- w, scale = block_quant_to_tensor_quant(
- weight, weight_scale, weight_block_size
- )
- self_attn.w_scale = scale
- if w.dtype == torch.int8:
- if hasattr(self.quant_config, "weight_block_size"):
- # block-wise int8 need it
- weight_block_size = self.quant_config.weight_block_size
- if weight_block_size is not None:
- assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
- weight = w
- weight_scale = self_attn.kv_b_proj.weight_scale_inv
- w = int8_block_dequant(weight, weight_scale, weight_block_size).to(
- torch.bfloat16
- )
- else:
- # channel-wise int8 need it
- assert hasattr(self_attn.kv_b_proj, "weight_scale")
- w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
- torch.bfloat16
- )
- w_kc, w_vc = w.unflatten(
- 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
- ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
- self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
- self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
- if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
- self_attn.w_scale = self_attn.kv_b_proj.weight_scale
- if _is_hip:
- self_attn.w_scale *= 2.0
+ super().load_weights(weights, is_nextn=True)


  EntryClass = [DeepseekV3ForCausalLMNextN]
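The hunk above deduplicates the NextN weight loading: the subclass drops its own copy of the logic and delegates to the parent's `load_weights` with `is_nextn=True`. A generic sketch of that delegation pattern, with hypothetical classes and a toy "loading" step only for illustration:

    class Base:
        def load_weights(self, weights, is_nextn=False):
            # Parent owns the loading logic; the flag switches NextN-specific handling.
            prefix = "nextn." if is_nextn else ""
            return [(prefix + name, w) for name, w in weights]

    class NextN(Base):
        def load_weights(self, weights):
            return super().load_weights(weights, is_nextn=True)

    print(NextN().load_weights([("eh_proj.weight", 0)]))  # [('nextn.eh_proj.weight', 0)]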