sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
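To pull the newer build into an existing environment, one standard route is a plain PyPI upgrade (the `[all]` extra shown here is an assumption about how sglang was originally installed):

    pip install --upgrade "sglang[all]==0.4.9.post5"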
- sglang/lang/chat_template.py +21 -0
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +16 -1
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mooncake/conn.py +16 -0
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +132 -79
- sglang/srt/function_call/ebnf_composer.py +10 -3
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/qwen3_coder_detector.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +14 -3
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +84 -22
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +25 -10
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/lora/lora_registry.py +93 -29
- sglang/srt/managers/cache_controller.py +9 -7
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +14 -8
- sglang/srt/managers/scheduler.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +37 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -2
- sglang/srt/model_executor/model_runner.py +68 -14
- sglang/srt/models/deepseek_v2.py +62 -28
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/qwen2_moe.py +2 -2
- sglang/srt/models/qwen3_moe.py +5 -2
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +57 -6
- sglang/srt/utils.py +96 -1
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +4 -4
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +83 -73
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4_moe_nextn.py (new file)
@@ -0,0 +1,167 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Inference-only GLM-4.5 NextN Speculative Decoding."""
+import logging
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
+from sglang.srt.utils import BumpAllocator, add_prefix
+
+logger = logging.getLogger(__name__)
+
+
+class Glm4MoeModelNextN(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
+            logger.warning(
+                "Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 model."
+            )
+            quant_config = None
+
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            prefix=add_prefix("embed_tokens", prefix),
+        )
+
+        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
+
+        self.decoder = Glm4MoeDecoderLayer(
+            config,
+            0,
+            quant_config=quant_config,
+            is_nextn=True,
+            prefix=add_prefix("decoder", prefix),
+        )
+
+        self.shared_head = nn.Module()
+        self.shared_head.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            buffer_size=2,
+            dtype=torch.float32,
+            device=(
+                input_embeds.device if input_embeds is not None else input_ids.device
+            ),
+        )
+
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        if hidden_states.shape[0] > 0:
+            hidden_states = self.eh_proj(
+                torch.cat(
+                    (
+                        self.enorm(hidden_states),
+                        self.hnorm(forward_batch.spec_info.hidden_states),
+                    ),
+                    dim=-1,
+                )
+            )
+
+        residual = None
+        with get_global_expert_distribution_recorder().disable_this_region():
+            hidden_states, residual = self.decoder(
+                positions, hidden_states, forward_batch, residual, zero_allocator
+            )
+
+        if not forward_batch.forward_mode.is_idle():
+            if residual is not None:
+                hidden_states, _ = self.shared_head.norm(hidden_states, residual)
+            else:
+                hidden_states = self.shared_head.norm(hidden_states)
+
+        return hidden_states
+
+
+class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.quant_config = quant_config
+        self.determine_num_fused_shared_experts("Glm4MoeForCausalLMNextN")
+
+        self.model = Glm4MoeModelNextN(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("model.shared_head.head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        super().load_weights(weights, is_nextn=True)
+
+
+EntryClass = [Glm4MoeForCausalLMNextN]
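The interesting piece of the new NextN module is the draft-input fusion in Glm4MoeModelNextN.forward: the embedding of the draft token and the target model's hidden state (carried in forward_batch.spec_info) are each RMS-normalized, concatenated, and projected back to hidden_size by eh_proj before the single MTP decoder layer runs. A minimal standalone sketch of that fusion in plain PyTorch, with hypothetical sizes (note nn.RMSNorm requires PyTorch >= 2.4; the file above uses sglang's own RMSNorm):

    # Sketch only: hypothetical hidden size and batch of 8 draft positions.
    import torch
    from torch import nn

    hidden_size = 4096
    enorm = nn.RMSNorm(hidden_size)      # normalizes the draft-token embedding
    hnorm = nn.RMSNorm(hidden_size)      # normalizes the target model's hidden state
    eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)

    token_embeds = torch.randn(8, hidden_size)   # embed_tokens(input_ids)
    target_hidden = torch.randn(8, hidden_size)  # forward_batch.spec_info.hidden_states
    fused = eh_proj(torch.cat((enorm(token_embeds), hnorm(target_hidden)), dim=-1))
    print(fused.shape)  # torch.Size([8, 4096]) -> input to the NextN decoder layer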
sglang/srt/models/interns1.py (new file)
@@ -0,0 +1,328 @@
+from typing import Iterable, List, Optional, Set, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.distributed import parallel_state
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.internvl import InternVisionModel
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
+from sglang.utils import logger
+
+
+class InternS1ForConditionalGeneration(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        use_flash_attn=True,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self._update_hf_config()
+        image_size = (
+            getattr(config, "force_image_size", None) or config.vision_config.image_size
+        )
+        patch_size = config.vision_config.patch_size
+        if isinstance(image_size, list):
+            image_size = image_size[0]
+        if isinstance(patch_size, list):
+            patch_size = patch_size[0]
+        self.patch_size = patch_size
+        self.select_layer = config.vision_feature_layer
+        self.num_image_token = int(
+            (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
+        )
+        self.downsample_ratio = config.downsample_ratio
+        self.ps_version = getattr(config, "ps_version", "v1")
+        # self.template = getattr(config, 'template', 'internvl2_5')
+
+        config.vision_config.use_flash_attn = True if use_flash_attn else False
+        config.text_config._attn_implementation = (
+            "flash_attention_2" if use_flash_attn else "eager"
+        )
+
+        logger.info(f"num_image_token: {self.num_image_token}")
+        logger.info(f"ps_version: {self.ps_version}")
+
+        self.vision_model = InternVisionModel(config.vision_config)
+        if config.text_config.architectures[0] == "Qwen2ForCausalLM":
+            self.language_model = Qwen2ForCausalLM(
+                config=config.text_config, quant_config=quant_config
+            )
+        elif config.text_config.architectures[0] == "Qwen3MoeForCausalLM":
+            self.language_model = Qwen3MoeForCausalLM(
+                config=config.text_config, quant_config=quant_config
+            )
+        else:
+            raise NotImplementedError(
+                f"{config.text_config.architectures[0]} is not implemented."
+            )
+
+        vit_hidden_size = config.vision_config.hidden_size
+        llm_hidden_size = config.text_config.hidden_size
+
+        self.mlp1 = nn.Sequential(
+            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
+            nn.Linear(
+                vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size
+            ),
+            nn.GELU(),
+            nn.Linear(llm_hidden_size, llm_hidden_size),
+        )
+
+    def _update_hf_config(self):
+        """update hf config to support tp"""
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        num_heads = self.config.vision_config.num_attention_heads
+        head_dim = self.config.vision_config.hidden_size // num_heads
+        num_dummy_heads = 0
+
+        if num_heads % world_size != 0:
+            num_dummy_heads = (
+                (num_heads + world_size) // world_size
+            ) * world_size - num_heads
+
+        setattr(self.config.vision_config, "head_dim", head_dim)
+        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
+
+    def pixel_shuffle(self, x, scale_factor=0.5):
+        n, w, h, c = x.size()
+        # N, W, H, C --> N, W, H * scale, C // scale
+        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
+        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
+        x = x.permute(0, 2, 1, 3).contiguous()
+        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
+        x = x.view(
+            n,
+            int(h * scale_factor),
+            int(w * scale_factor),
+            int(c / (scale_factor * scale_factor)),
+        )
+        if self.ps_version == "v1":
+            logger.warn(
+                "In ps_version 'v1', the height and width have not been swapped back, "
+                "which results in a transposed image."
+            )
+        else:
+            x = x.permute(0, 2, 1, 3).contiguous()
+        return x
+
+    def extract_feature(self, pixel_values):
+        if self.select_layer == -1:
+            vit_embeds = self.vision_model(
+                pixel_values=pixel_values, output_hidden_states=False, return_dict=True
+            ).last_hidden_state
+        else:
+            vit_embeds = self.vision_model(
+                pixel_values=pixel_values, output_hidden_states=True, return_dict=True
+            ).hidden_states[self.select_layer]
+        vit_embeds = vit_embeds[:, 1:, :]
+
+        h = w = int(vit_embeds.shape[1] ** 0.5)
+        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
+        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
+        vit_embeds = self.mlp1(vit_embeds)
+        return vit_embeds
+
+    def get_image_feature(self, items: List[MultimodalDataItem]):
+        """
+        Projects the last hidden state from the vision model into language model space.
+
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
+        """
+        pixel_values = torch.cat([item.feature for item in items])
+        image_features = self.extract_feature(pixel_values)
+        return image_features
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+
+        hs = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
+            positions=positions,
+        )
+
+        return hs
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        # Get all special token IDs
+        im_start_id: int = mm_inputs.im_start_id
+        im_end_id: int = mm_inputs.im_end_id
+
+        media_token_pairs = [(im_start_id, im_end_id)]
+        helper = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+
+        return helper.pad_input_tokens(input_ids, mm_inputs)
+
+    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+        """pad attn qkv weights for dummy heads"""
+        num_dummy_heads = self.config.vision_config.num_dummy_heads
+        if num_dummy_heads == 0:
+            return loaded_weight
+        head_dim = self.config.vision_config.head_dim
+
+        if any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]):
+            if name.endswith(".weight"):
+                dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]]
+            elif name.endswith(".bias"):
+                dummy_shape = [num_dummy_heads, head_dim]
+            else:
+                raise RuntimeError(f"Unsupported weight with name={name}")
+            padded_weight = loaded_weight.new_zeros(dummy_shape)
+            loaded_weight = torch.cat(
+                [loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0
+            ).flatten(0, 1)
+        if "attn.proj.weight" in name:
+            padded_weight = loaded_weight.new_zeros(
+                loaded_weight.shape[0], head_dim * num_dummy_heads
+            )
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+        return loaded_weight
+
+    def _mapping_interns1_name(self, name):
+        names_map = {
+            "lm_head.weight": "language_model.lm_head.weight",
+            "model.multi_modal_projector.layer_norm.bias": "mlp1.0.bias",
+            "model.multi_modal_projector.layer_norm.weight": "mlp1.0.weight",
+            "model.multi_modal_projector.linear_1.bias": "mlp1.1.bias",
+            "model.multi_modal_projector.linear_1.weight": "mlp1.1.weight",
+            "model.multi_modal_projector.linear_2.bias": "mlp1.3.bias",
+            "model.multi_modal_projector.linear_2.weight": "mlp1.3.weight",
+            "model.vision_tower.embeddings.cls_token": "vision_model.embeddings.class_embedding",
+            "model.vision_tower.embeddings.patch_embeddings.projection.bias": "vision_model.embeddings.patch_embedding.bias",
+            "model.vision_tower.embeddings.patch_embeddings.projection.weight": "vision_model.embeddings.patch_embedding.weight",
+            "model.vision_tower.embeddings.position_embeddings": "vision_model.embeddings.position_embedding",
+        }
+        if name in names_map:
+            name = names_map[name]
+        elif name.startswith("model.language_model."):
+            name = "language_model.model." + name[len("model.language_model.") :]
+        elif name.startswith("model.vision_tower."):
+            name = "vision_model." + name[len("model.vision_tower.") :]
+
+        if name.startswith("vision_model.encoder.layer"):
+
+            name = name.replace(r".layer.", r".layers.")
+            name = name.replace(r".attention.", r".attn.attn.")
+            name = name.replace(r".projection_layer.", r".proj.")
+            name = name.replace(r".lambda_1", r".ls1")
+            name = name.replace(r".lambda_2", r".ls2")
+            name = name.replace(r".layernorm_before.", r".norm1.")
+            name = name.replace(r".layernorm_after.", r".norm2.")
+        return name
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        expert_params_mapping = []
+        if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
+            expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+                ckpt_gate_proj_name="gate_proj",
+                ckpt_down_proj_name="down_proj",
+                ckpt_up_proj_name="up_proj",
+                num_experts=self.config.num_experts,
+            )
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            name = self._mapping_interns1_name(name)
+            if "vision_model" in name:
+                loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+            loaded_params.add(name)
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            raise RuntimeError(
+                f"Some weights are not initialized from checkpoints: {unloaded_params}"
+            )
+        return loaded_params
+
+
+EntryClass = [InternS1ForConditionalGeneration]
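A quick way to see what InternS1ForConditionalGeneration.pixel_shuffle does to the vision tokens before they reach the mlp1 projector: with the default scale_factor of 0.5, the token grid is halved in each spatial dimension while the channel dimension grows 4x, which is why mlp1's first LayerNorm is sized vit_hidden_size * int(1 / downsample_ratio) ** 2. A minimal sketch in plain PyTorch with assumed sizes (a 32x32 grid of 1024-dim ViT features):

    # Sketch only: hypothetical ViT grid and hidden size.
    import torch

    n, h, w, c = 1, 32, 32, 1024
    scale = 0.5
    x = torch.randn(n, w, h, c)
    x = x.view(n, w, int(h * scale), int(c / scale))                    # (1, 32, 16, 2048)
    x = x.permute(0, 2, 1, 3).contiguous()                              # (1, 16, 32, 2048)
    x = x.view(n, int(h * scale), int(w * scale), int(c / scale**2))    # (1, 16, 16, 4096)
    print(x.shape)  # 1024 image tokens reduced to 256, channels expanded 4x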