sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4v_moe.py
@@ -0,0 +1,400 @@
+import logging
+from functools import lru_cache
+from typing import Iterable, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
+
+from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    parallel_state,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    get_local_attention_dp_size,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.glm4_moe import Glm4MoeModel
+from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
+from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
+
+_is_cuda = is_cuda()
+
+logger = logging.getLogger(__name__)
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
+    def __init__(
+        self,
+        config: Glm4vMoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+
+        config.moe_layer_freq = 1
+        self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.dp_size = get_local_attention_dp_size()
+        self.quant_config = quant_config
+        self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
+        self.num_fused_shared_experts = (
+            0
+            if global_server_args_dict["disable_shared_experts_fusion"]
+            else config.n_shared_experts
+        )
+
+        self.model = Glm4MoeModel(
+            config,
+            quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
+        self.visual = Glm4vVisionModel(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
+            quant_config=quant_config,
+            prefix=add_prefix("visual", prefix),
+        )
+
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+
+    def determine_num_fused_shared_experts(
+        self, architecture: str = "Glm4MoeForCausalLM"
+    ):
+        self.num_fused_shared_experts = 0
+        if global_server_args_dict["disable_shared_experts_fusion"]:
+            return
+
+        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+        disable_reason = None
+        if (
+            not _is_cuda
+            or torch.cuda.get_device_capability("cuda") < (8, 0)
+            or self.config.architectures[0] != architecture
+            or self.config.n_shared_experts != 1
+        ):
+            disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
+        elif get_moe_expert_parallel_world_size() > 1:
+            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
+
+        if disable_reason is not None:
+            global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
+            log_info_on_rank0(
+                logger,
+                f"{disable_reason} Shared experts fusion optimization is disabled.",
+            )
+            return
+
+        self.num_fused_shared_experts = self.config.n_shared_experts
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+
+        if is_nextn:
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
+                # compatible with old design
+                nextn_layer_id = (
+                    0
+                    if self.config.num_hidden_layers == 1
+                    else self.config.num_hidden_layers
+                )
+            else:
+                raise ValueError("num_nextn_predict_layers is not in the config")
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        if self.num_fused_shared_experts > 0:
+            assert self.num_fused_shared_experts == 1
+            weights_list = list(weights)
+            weights_dict = dict(weights_list)
+            if self.quant_config is not None:
+                if self.quant_config.get_name() == "w8a8_int8":
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                    ]
+                elif (
+                    self.quant_config.get_name() == "fp8"
+                    or self.quant_config.get_name() == "blockwise_int8"
+                    or self.quant_config.get_name() == "compressed_tensors"
+                ):
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                    ]
+                elif self.quant_config.get_name() == "awq":
+                    suffix_list = [
+                        "down_proj.qweight",
+                        "down_proj.qzeros",
+                        "down_proj.scales",
+                        "gate_proj.qweight",
+                        "gate_proj.qzeros",
+                        "gate_proj.scales",
+                        "up_proj.qweight",
+                        "up_proj.qzeros",
+                        "up_proj.scales",
+                    ]
+                elif self.quant_config.get_name() == "modelopt_fp4":
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "down_proj.weight_scale_2",
+                        "down_proj.input_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "gate_proj.weight_scale_2",
+                        "gate_proj.input_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                        "up_proj.weight_scale_2",
+                        "up_proj.input_scale",
+                    ]
+                else:
+                    raise ValueError(
+                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
+                    )
+            else:
+                suffix_list = [
+                    "down_proj.weight",
+                    "gate_proj.weight",
+                    "up_proj.weight",
+                ]
+            names_to_remove = []
+
+            moe_layers = (
+                range(
+                    self.config.first_k_dense_replace,
+                    self.config.num_hidden_layers,
+                    self.config.moe_layer_freq,
+                )
+                if not is_nextn
+                else [nextn_layer_id]
+            )
+
+            for moe_layer in moe_layers:
+                for suffix in suffix_list:
+                    shared_expert_weight_name = (
+                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
+                    )
+                    # online fp8 quantization does not load weight_scale
+                    if shared_expert_weight_name not in weights_dict:
+                        continue
+                    weights_list.append(
+                        (
+                            f"model.layers.{moe_layer}."
+                            f"mlp.experts."
+                            f"{self.config.n_routed_experts + 0}"
+                            f".{suffix}",
+                            weights_dict[shared_expert_weight_name],
+                        )
+                    )
+                    names_to_remove += [shared_expert_weight_name]
+            weights = [w for w in weights_list if w[0] not in names_to_remove]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
+        )
+
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+            self.config.q_lora_rank is not None
+        )
+        cached_a_proj = {} if fuse_qkv_a_proj else None
+
+        if is_nextn:
+            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+            nextn_spec_weight_names = [
+                "shared_head.norm",
+                "eh_proj",
+                "enorm",
+                "hnorm",
+            ]
+
+        params_dict = dict(self.named_parameters())
+        weight_names = []
+        for name, loaded_weight in weights:
+            weight_names.append(name)
+
+            if not is_nextn:
+                if hasattr(self.config, "num_nextn_predict_layers"):
+                    num_nextn_layers = self.config.num_nextn_predict_layers
+                    if num_nextn_layers > 0 and name.startswith("model.layers"):
+                        name_list = name.split(".")
+                        if (
+                            len(name_list) >= 3
+                            and int(name_list[2]) >= self.config.num_hidden_layers
+                        ):
+                            continue
+            else:
+                if not name.startswith(nextn_layer_prefix):
+                    continue
+
+                # Use shared head and embed weights from target model
+                if "shared_head.head" in name or "embed_tokens" in name:
+                    continue
+
+                is_decoder = True
+                # For nextn specific weights
+                for weight_name in nextn_spec_weight_names:
+                    if weight_name in name:
+                        name = name.replace(nextn_layer_prefix, "model")
+                        is_decoder = False
+                        break
+                # For decoder layer weights
+                if is_decoder:
+                    name = name.replace(nextn_layer_prefix, "model.decoder")
+
+            if "language_model." in name:
+                name = name.replace("language_model.", "")
+            if "model.visual." in name:
+                name = name.replace("model.visual.", "visual.")
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if ("mlp.experts." in name) and name not in params_dict:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    if "visual" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if fuse_qkv_a_proj and (
+                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                    ):
+                        cached_a_proj[name] = loaded_weight
+                        q_a_proj_name = (
+                            name
+                            if "q_a_proj" in name
+                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                        )
+                        kv_a_proj_name = (
+                            name
+                            if "kv_a_proj_with_mqa" in name
+                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                        )
+
+                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                        if (
+                            q_a_proj_name in cached_a_proj
+                            and kv_a_proj_name in cached_a_proj
+                        ):
+                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                            fused_weight = torch.cat(
+                                [q_a_proj_weight, kv_a_proj_weight], dim=0
+                            )
+                            param_name = (
+                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                                if "q_a_proj" in name
+                                else name.replace(
+                                    "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+                                )
+                            )
+                            param = params_dict[param_name]
+
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            weight_loader(param, fused_weight)
+                            cached_a_proj.pop(q_a_proj_name)
+                            cached_a_proj.pop(kv_a_proj_name)
+                    else:
+                        if (
+                            "k_scale" in name or "v_scale" in name
+                        ) and name not in params_dict:
+                            # modelopt attn kv scale is named differently
+                            if any(scale in name for scale in ["k_scale", "v_scale"]):
+                                name = name.replace("_proj", "attn_mqa")
+                            else:
+                                logger.warning(
+                                    f"Unknown scale found in checkpoint: {name}"
+                                )
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+
+
+EntryClass = [Glm4vMoeForConditionalGeneration]