sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -0
- sglang/check_env.py +3 -3
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +15 -0
- sglang/srt/conversation.py +122 -1
- sglang/srt/entrypoints/engine.py +44 -22
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +107 -82
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -6
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +1 -1
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +84 -35
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +25 -15
- sglang/srt/managers/scheduler.py +263 -59
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
- sglang/srt/managers/tp_worker.py +51 -16
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +115 -57
- sglang/srt/models/deepseek_nextn.py +1 -257
- sglang/srt/models/deepseek_v2.py +78 -18
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +92 -30
- sglang/srt/models/llama4.py +2 -1
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +0 -12
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/openai_api/adapter.py +34 -22
- sglang/srt/openai_api/protocol.py +11 -1
- sglang/srt/server_args.py +67 -22
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +88 -9
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +29 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +61 -51
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0

sglang/srt/model_executor/model_runner.py

@@ -13,8 +13,10 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""
 
+import collections
 import datetime
 import gc
+import inspect
 import json
 import logging
 import os
@@ -59,7 +61,7 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import (
     DefaultModelLoader,
@@ -110,6 +112,8 @@ class ModelRunner:
         gpu_id: int,
         tp_rank: int,
         tp_size: int,
+        pp_rank: int,
+        pp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
         is_draft_worker: bool = False,
@@ -123,6 +127,8 @@ class ModelRunner:
         self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.pp_rank = pp_rank
+        self.pp_size = pp_size
         self.dist_port = nccl_port
         self.server_args = server_args
         self.is_draft_worker = is_draft_worker
@@ -148,24 +154,24 @@ class ModelRunner:
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
-                "sampling_backend": server_args.sampling_backend,
-                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
-                "torchao_config": server_args.torchao_config,
+                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
+                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
+                "deepep_mode": server_args.deepep_mode,
+                "device": server_args.device,
+                "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
+                "disable_radix_cache": server_args.disable_radix_cache,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "enable_deepep_moe": server_args.enable_deepep_moe,
-                "deepep_mode": server_args.deepep_mode,
-                "device": server_args.device,
-                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
-                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
-                "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "moe_dense_tp_size": server_args.moe_dense_tp_size,
-                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
-                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
                 "n_share_experts_fusion": server_args.n_share_experts_fusion,
-                "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
+                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
+                "torchao_config": server_args.torchao_config,
+                "sampling_backend": server_args.sampling_backend,
+                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
+                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "use_mla_backend": self.use_mla_backend,
             }
         )
@@ -183,6 +189,11 @@ class ModelRunner:
         # If it is a draft model, tp_group can be different
         self.initialize(min_per_gpu_memory)
 
+        # temporary cached values
+        self.support_pp = (
+            "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
+        )
+
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
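
Note: the support_pp flag cached above is derived purely from the model's forward signature. A minimal standalone sketch of the same detection pattern, using hypothetical model classes (not sglang's) for illustration:

import inspect

class PPAwareModel:
    # Hypothetical stand-in for a model whose forward accepts pipeline-parallel proxy tensors.
    def forward(self, input_ids, positions, forward_batch, pp_proxy_tensors=None):
        return "logits-or-proxy"

class PlainModel:
    # Hypothetical stand-in for a model without pipeline-parallel support.
    def forward(self, input_ids, positions, forward_batch):
        return "logits"

def supports_pp(model) -> bool:
    # Same check as ModelRunner.support_pp: is the keyword present in the signature?
    return "pp_proxy_tensors" in inspect.signature(model.forward).parameters

print(supports_pp(PPAwareModel()))  # True
print(supports_pp(PlainModel()))    # False
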
@@ -193,6 +204,12 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()
 
+        self.start_layer = getattr(self.model, "start_layer", 0)
+        self.end_layer = getattr(
+            self.model, "end_layer", self.model_config.num_hidden_layers
+        )
+        self.num_effective_layers = self.end_layer - self.start_layer
+
         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
         # In layered loading, torchao may have been applied
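
The start_layer / end_layer attributes are read off the loaded model, defaulting to the full depth, and num_effective_layers is what the KV-cache sizing below uses. As an illustration only, a hypothetical even split of layers across pipeline stages could look like this; sglang's actual partitioning lives in the model implementations, not in this helper:

def stage_layer_range(num_hidden_layers: int, pp_rank: int, pp_size: int):
    # Hypothetical helper: split layers evenly, earlier stages absorb the remainder.
    per_stage, remainder = divmod(num_hidden_layers, pp_size)
    start = pp_rank * per_stage + min(pp_rank, remainder)
    end = start + per_stage + (1 if pp_rank < remainder else 0)
    return start, end

# 61 layers over 4 stages -> (0, 16), (16, 31), (31, 46), (46, 61)
for rank in range(4):
    start, end = stage_layer_range(61, rank, 4)
    print(rank, start, end, "num_effective_layers =", end - start)
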
@@ -359,18 +376,22 @@ class ModelRunner:
         # Only initialize the distributed environment on the target model worker.
         init_distributed_environment(
             backend=backend,
-            world_size=self.tp_size,
-            rank=self.tp_rank,
+            world_size=self.tp_size * self.pp_size,
+            rank=self.tp_size * self.pp_rank + self.tp_rank,
             local_rank=self.gpu_id,
             distributed_init_method=dist_init_method,
             timeout=self.server_args.dist_timeout,
         )
-        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        initialize_model_parallel(
+            tensor_model_parallel_size=self.tp_size,
+            pipeline_model_parallel_size=self.pp_size,
+        )
         initialize_dp_attention(
             enable_dp_attention=self.server_args.enable_dp_attention,
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             dp_size=self.server_args.dp_size,
+            pp_size=self.server_args.pp_size,
         )
 
         min_per_gpu_memory = get_available_gpu_memory(
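
With pipeline parallelism in the picture, the process group grows to tp_size * pp_size ranks and the global rank becomes tp_size * pp_rank + tp_rank, so the ranks of one pipeline stage stay contiguous. A small sketch of the resulting layout:

def rank_grid(tp_size: int, pp_size: int):
    # Mirrors the mapping above: rank = tp_size * pp_rank + tp_rank.
    return {
        (pp_rank, tp_rank): tp_size * pp_rank + tp_rank
        for pp_rank in range(pp_size)
        for tp_rank in range(tp_size)
    }

for (pp_rank, tp_rank), rank in sorted(rank_grid(tp_size=4, pp_size=2).items()):
    print(f"pp_rank={pp_rank} tp_rank={tp_rank} -> global rank {rank}")
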
@@ -692,16 +713,23 @@ class ModelRunner:
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         if self.use_mla_backend:
+            num_layers = (
+                self.model_config.num_hidden_layers
+                if not self.is_draft_worker
+                else self.model_config.hf_config.num_nextn_predict_layers
+            )
+            # FIXME: pipeline parallelism is not compatible with mla backend
+            assert self.pp_size == 1
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
-                * self.model_config.num_hidden_layers
+                * num_layers
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
         else:
             cell_size = (
                 self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
-                * self.model_config.num_hidden_layers
+                * self.num_effective_layers
                 * 2
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
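
The cell_size above is the per-token KV-cache footprint used for memory profiling: the MLA branch stores one compressed latent of width kv_lora_rank + qk_rope_head_dim per layer, while the MHA branch stores K and V per head per layer, and now only for the layers owned by this pipeline stage. A small sketch with illustrative numbers (DeepSeek-V3-like MLA dimensions, a generic MHA config, bf16 cache):

import torch

def mla_cell_size(kv_lora_rank, qk_rope_head_dim, num_layers, kv_cache_dtype):
    # MLA branch: one compressed KV vector per token per layer.
    return (
        (kv_lora_rank + qk_rope_head_dim)
        * num_layers
        * torch._utils._element_size(kv_cache_dtype)
    )

def mha_cell_size(num_kv_heads, head_dim, num_effective_layers, kv_cache_dtype):
    # MHA branch: K and V (factor 2) per token, only for this stage's layers.
    return (
        num_kv_heads
        * head_dim
        * num_effective_layers
        * 2
        * torch._utils._element_size(kv_cache_dtype)
    )

print(mla_cell_size(512, 64, 61, torch.bfloat16), "bytes/token (MLA, illustrative)")
print(mha_cell_size(8, 128, 32, torch.bfloat16), "bytes/token (MHA, illustrative)")
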
@@ -809,9 +837,15 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=self.model_config.num_hidden_layers,
+                layer_num=(
+                    self.model_config.num_hidden_layers
+                    if not self.is_draft_worker
+                    else self.model_config.hf_config.num_nextn_predict_layers
+                ),  # PP is not compatible with mla backend
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
@@ -820,10 +854,12 @@
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
-                layer_num=self.model_config.num_hidden_layers,
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 heavy_channel_num=self.server_args.ds_heavy_channel_num,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -832,9 +868,11 @@
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
-                layer_num=self.model_config.num_hidden_layers,
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )
 
         if self.token_to_kv_pool_allocator is None:
@@ -918,8 +956,10 @@ class ModelRunner:
 
             self.attn_backend = FlashMLABackend(self)
         elif self.server_args.attention_backend == "fa3":
-            assert torch.cuda.get_device_capability()[0] >= 9, (
-                "FlashAttention v3 Backend requires SM>=90. "
+            assert (
+                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
+            ) or torch.cuda.get_device_capability()[0] == 9, (
+                "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
                 "Please use `--attention-backend flashinfer`."
             )
             from sglang.srt.layers.attention.flashattention_backend import (
@@ -945,7 +985,7 @@ class ModelRunner:
         with open(self.server_args.ds_channel_config_path, "r") as f:
             channel_config = json.load(f)
 
-        for i in range(self.model_config.num_hidden_layers):
+        for i in range(self.start_layer, self.end_layer):
             key = "model.layers." + str(i) + ".self_attn" + selected_channel
             self.sorted_channels.append(
                 torch.tensor(channel_config[key])[
@@ -985,64 +1025,82 @@ class ModelRunner:
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
         tensor_parallel(self.model, device_mesh)
 
-    def forward_decode(self, forward_batch: ForwardBatch):
+    def forward_decode(
+        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
+    ) -> LogitsProcessorOutput:
         self.attn_backend.init_forward_metadata(forward_batch)
+        # FIXME: add pp_proxy_tensors arg to all models
+        kwargs = {}
+        if self.support_pp:
+            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
         return self.model.forward(
-            forward_batch.input_ids, forward_batch.positions, forward_batch
+            forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
         )
 
     def forward_extend(
-        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
-    ):
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors=None,
+    ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
             self.attn_backend.init_forward_metadata(forward_batch)
 
-        if self.is_generation:
-            if forward_batch.input_embeds is None:
-                return self.model.forward(
-                    forward_batch.input_ids, forward_batch.positions, forward_batch
-                )
-            else:
-                return self.model.forward(
-                    forward_batch.input_ids,
-                    forward_batch.positions,
-                    forward_batch,
-                    input_embeds=forward_batch.input_embeds.bfloat16(),
-                )
-        else:
-            # Only embedding models have get_embedding parameter
-            return self.model.forward(
-                forward_batch.input_ids,
-                forward_batch.positions,
-                forward_batch,
-                get_embedding=True,
-            )
+        kwargs = {}
+        if self.support_pp:
+            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
+        if forward_batch.input_embeds is not None:
+            kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
+        if not self.is_generation:
+            kwargs["get_embedding"] = True
+        return self.model.forward(
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            **kwargs,
+        )
 
-    def forward_idle(self, forward_batch: ForwardBatch):
+    def forward_idle(
+        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
+    ) -> LogitsProcessorOutput:
+        kwargs = {}
+        if self.support_pp:
+            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
         return self.model.forward(
-            forward_batch.input_ids, forward_batch.positions, forward_batch
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            **kwargs,
         )
 
     def forward(
-        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
-    ) -> LogitsProcessorOutput:
-        if (
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
-        ):
+        )
+        if can_run_cuda_graph:
             return self.cuda_graph_runner.replay(
-                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
             )
 
         if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(forward_batch)
+            return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         elif forward_batch.forward_mode.is_extend():
             return self.forward_extend(
-                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
             )
         elif forward_batch.forward_mode.is_idle():
-            return self.forward_idle(forward_batch)
+            return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
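
The forward entry points now collect their optional arguments in one kwargs dict (pp_proxy_tensors only when the model supports it, input_embeds and get_embedding only when needed) instead of branching into separate model.forward calls. A minimal sketch of that pattern, with a hypothetical stub in place of the real model and ForwardBatch:

def model_forward(input_ids, positions, batch, pp_proxy_tensors=None,
                  input_embeds=None, get_embedding=False):
    # Hypothetical stub standing in for model.forward.
    return {
        "pp_proxy_tensors": pp_proxy_tensors,
        "used_input_embeds": input_embeds is not None,
        "get_embedding": get_embedding,
    }

def forward_extend_sketch(input_ids, positions, batch, *, support_pp, is_generation,
                          input_embeds=None, pp_proxy_tensors=None):
    kwargs = {}
    if support_pp:
        kwargs["pp_proxy_tensors"] = pp_proxy_tensors
    if input_embeds is not None:
        kwargs["input_embeds"] = input_embeds
    if not is_generation:
        # Only embedding models request pooled embeddings.
        kwargs["get_embedding"] = True
    return model_forward(input_ids, positions, batch, **kwargs)

print(forward_extend_sketch([1], [0], None, support_pp=True, is_generation=True,
                            pp_proxy_tensors="prev-stage-tensors"))
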

sglang/srt/models/deepseek_nextn.py

@@ -177,263 +177,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        if hasattr(self.config, "num_nextn_predict_layers"):
-            num_nextn_layers = self.config.num_nextn_predict_layers
-            assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
-            assert num_nextn_layers == self.config.num_hidden_layers
-        else:
-            raise ValueError("num_nextn_predict_layers is not in the config")
-
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        if self.n_share_experts_fusion > 0:
-            logger.info(
-                f"Cloning {self.n_share_experts_fusion} "
-                "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
-            )
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale",
-                    "up_proj.weight",
-                    "up_proj.weight_scale",
-                ]
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale_inv",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale_inv",
-                    "up_proj.weight",
-                    "up_proj.weight_scale_inv",
-                ]
-            names_to_remove = []
-            for suffix in suffix_list:
-                shared_expert_weight_name = (
-                    f"model.layers.0.mlp.shared_experts.{suffix}"
-                )
-                for num_repeat in range(self.n_share_experts_fusion):
-                    weights_list.append(
-                        (
-                            f"model.layers.0."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + num_repeat}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
-        )
-
-        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
-        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
-            self.config.q_lora_rank is not None
-        )
-        cached_a_proj = {} if fuse_qkv_a_proj else None
-
-        nextn_layer_prefix = "model.layers.0"
-        nextn_spec_weight_names = [
-            "shared_head.norm",
-            "eh_proj",
-            "enorm",
-            "hnorm",
-        ]
-
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if not name.startswith(nextn_layer_prefix):
-                continue
-
-            # Use shared head and embed weights from target model
-            if "shared_head.head" in name or "embed_tokens" in name:
-                continue
-
-            is_decoder = True
-            # For nextn specific weights
-            for weight_name in nextn_spec_weight_names:
-                if weight_name in name:
-                    name = name.replace(nextn_layer_prefix, "model")
-                    is_decoder = False
-                    break
-            # For decoder layer weights
-            if is_decoder:
-                name = name.replace(nextn_layer_prefix, "model.decoder")
-
-            if "rotary_emb.inv_freq" in name:
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                # Skip non-stacked layers and experts (experts handled below).
-                if weight_name not in name:
-                    continue
-                # We have mlp.experts[0].gate_proj in the checkpoint.
-                # Since we handle the experts below in expert_params_mapping,
-                # we need to skip here BEFORE we update the name, otherwise
-                # name will be updated to mlp.experts[0].gate_up_proj, which
-                # will then be updated below in expert_params_mapping
-                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-
-                    # Handle fused_qkv_a_proj
-                    if fuse_qkv_a_proj and (
-                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                    ):
-                        cached_a_proj[name] = loaded_weight
-                        q_a_proj_name = (
-                            name
-                            if "q_a_proj" in name
-                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                        )
-                        kv_a_proj_name = (
-                            name
-                            if "kv_a_proj_with_mqa" in name
-                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                        )
-
-                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                        if (
-                            q_a_proj_name in cached_a_proj
-                            and kv_a_proj_name in cached_a_proj
-                        ):
-
-                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                            fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=0
-                            )
-
-                            param_name = name.replace(
-                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
-                            )
-                            param = params_dict[param_name]
-
-                            weight_loader = getattr(
-                                param, "weight_loader", default_weight_loader
-                            )
-                            weight_loader(param, fused_weight)
-                            cached_a_proj.pop(q_a_proj_name)
-                            cached_a_proj.pop(kv_a_proj_name)
-                    else:
-                        param = params_dict[name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight)
-
-        self_attn = self.model.decoder.self_attn
-        if hasattr(self_attn.kv_b_proj, "qweight"):
-            # AWQ compatible
-            if _is_cuda:
-                w = awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                ).T
-            else:
-                w = awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                    0,
-                    0,
-                    0,
-                ).T
-        else:
-            w = self_attn.kv_b_proj.weight
-        # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-        # This may affect the accuracy of fp8 model.
-        if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-            torch.float8_e4m3fn,
-            torch.float8_e4m3fnuz,
-        ):
-            weight_block_size = self.quant_config.weight_block_size
-            if weight_block_size is not None:
-                assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                if _is_hip:
-                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=w,
-                        weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                        input_scale=None,
-                    )
-                else:
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                w, scale = block_quant_to_tensor_quant(
-                    weight, weight_scale, weight_block_size
-                )
-                self_attn.w_scale = scale
-        if w.dtype == torch.int8:
-            if hasattr(self.quant_config, "weight_block_size"):
-                # block-wise int8 need it
-                weight_block_size = self.quant_config.weight_block_size
-                if weight_block_size is not None:
-                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                    w = int8_block_dequant(weight, weight_scale, weight_block_size).to(
-                        torch.bfloat16
-                    )
-            else:
-                # channel-wise int8 need it
-                assert hasattr(self_attn.kv_b_proj, "weight_scale")
-                w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                    torch.bfloat16
-                )
-        w_kc, w_vc = w.unflatten(
-            0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-        ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-        self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-        self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-        if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
-            self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-            if _is_hip:
-                self_attn.w_scale *= 2.0
+        super().load_weights(weights, is_nextn=True)
 
 
 EntryClass = [DeepseekV3ForCausalLMNextN]
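
The entire NextN-specific loader is removed; DeepseekV3ForCausalLMNextN.load_weights now delegates to the parent class with an is_nextn=True flag, presumably to the shared implementation in deepseek_v2.py (which also changes in this release). The delegation pattern, sketched with hypothetical classes:

class BaseModel:
    # Hypothetical parent: one loader handles both the target model and the NextN draft head.
    def load_weights(self, weights, is_nextn: bool = False):
        prefix = "model.decoder." if is_nextn else "model.layers."
        return [prefix + name for name, _ in weights]

class NextNModel(BaseModel):
    def load_weights(self, weights):
        # All NextN-specific remapping now lives in the parent, gated by is_nextn=True.
        return super().load_weights(weights, is_nextn=True)

print(NextNModel().load_weights([("0.self_attn.kv_b_proj.weight", None)]))
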