sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/configs/internvl.py +3 -0
  3. sglang/srt/configs/model_config.py +7 -0
  4. sglang/srt/constrained/base_grammar_backend.py +10 -2
  5. sglang/srt/constrained/xgrammar_backend.py +7 -5
  6. sglang/srt/conversation.py +16 -1
  7. sglang/srt/debug_utils/__init__.py +0 -0
  8. sglang/srt/debug_utils/dump_comparator.py +131 -0
  9. sglang/srt/debug_utils/dumper.py +108 -0
  10. sglang/srt/debug_utils/text_comparator.py +172 -0
  11. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  12. sglang/srt/disaggregation/mooncake/conn.py +16 -0
  13. sglang/srt/disaggregation/prefill.py +13 -1
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +13 -1
  16. sglang/srt/entrypoints/openai/protocol.py +3 -1
  17. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  18. sglang/srt/entrypoints/openai/serving_chat.py +132 -79
  19. sglang/srt/function_call/ebnf_composer.py +10 -3
  20. sglang/srt/function_call/function_call_parser.py +2 -0
  21. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  22. sglang/srt/function_call/qwen3_coder_detector.py +1 -0
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  24. sglang/srt/layers/attention/vision.py +56 -8
  25. sglang/srt/layers/layernorm.py +26 -1
  26. sglang/srt/layers/logits_processor.py +14 -3
  27. sglang/srt/layers/moe/ep_moe/layer.py +323 -242
  28. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  33. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  34. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  35. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  36. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  37. sglang/srt/layers/moe/topk.py +90 -24
  38. sglang/srt/layers/multimodal.py +11 -8
  39. sglang/srt/layers/quantization/fp8.py +25 -247
  40. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  41. sglang/srt/layers/quantization/modelopt_quant.py +27 -10
  42. sglang/srt/layers/quantization/unquant.py +24 -76
  43. sglang/srt/layers/quantization/w4afp8.py +68 -17
  44. sglang/srt/lora/lora_registry.py +93 -29
  45. sglang/srt/managers/cache_controller.py +9 -7
  46. sglang/srt/managers/data_parallel_controller.py +4 -0
  47. sglang/srt/managers/io_struct.py +12 -0
  48. sglang/srt/managers/mm_utils.py +154 -35
  49. sglang/srt/managers/multimodal_processor.py +3 -14
  50. sglang/srt/managers/schedule_batch.py +14 -8
  51. sglang/srt/managers/scheduler.py +64 -1
  52. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  53. sglang/srt/managers/tokenizer_manager.py +80 -15
  54. sglang/srt/managers/tp_worker.py +8 -0
  55. sglang/srt/mem_cache/hiradix_cache.py +5 -2
  56. sglang/srt/model_executor/model_runner.py +83 -27
  57. sglang/srt/models/deepseek_v2.py +75 -84
  58. sglang/srt/models/glm4_moe.py +1035 -0
  59. sglang/srt/models/glm4_moe_nextn.py +167 -0
  60. sglang/srt/models/interns1.py +328 -0
  61. sglang/srt/models/internvl.py +143 -47
  62. sglang/srt/models/llava.py +9 -5
  63. sglang/srt/models/minicpmo.py +4 -1
  64. sglang/srt/models/qwen2_moe.py +2 -2
  65. sglang/srt/models/qwen3_moe.py +17 -71
  66. sglang/srt/multimodal/processors/base_processor.py +20 -6
  67. sglang/srt/multimodal/processors/clip.py +2 -2
  68. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  69. sglang/srt/multimodal/processors/gemma3.py +2 -2
  70. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  71. sglang/srt/multimodal/processors/internvl.py +21 -8
  72. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  73. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  74. sglang/srt/multimodal/processors/llava.py +4 -4
  75. sglang/srt/multimodal/processors/minicpm.py +2 -3
  76. sglang/srt/multimodal/processors/mlama.py +2 -2
  77. sglang/srt/multimodal/processors/mllama4.py +18 -111
  78. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  79. sglang/srt/multimodal/processors/pixtral.py +2 -2
  80. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  81. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  82. sglang/srt/multimodal/processors/vila.py +3 -1
  83. sglang/srt/poll_based_barrier.py +31 -0
  84. sglang/srt/reasoning_parser.py +2 -1
  85. sglang/srt/server_args.py +65 -6
  86. sglang/srt/two_batch_overlap.py +8 -3
  87. sglang/srt/utils.py +96 -1
  88. sglang/srt/weight_sync/utils.py +119 -0
  89. sglang/test/runners.py +4 -0
  90. sglang/test/test_utils.py +118 -5
  91. sglang/utils.py +19 -0
  92. sglang/version.py +1 -1
  93. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
  94. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
  95. sglang/srt/debug_utils.py +0 -74
  96. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
  97. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
  98. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4_moe_nextn.py (new file)
@@ -0,0 +1,167 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Inference-only GLM-4.5 NextN Speculative Decoding."""
+ import logging
+ from typing import Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import PretrainedConfig
+
+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
+ from sglang.srt.utils import BumpAllocator, add_prefix
+
+ logger = logging.getLogger(__name__)
+
+
+ class Glm4MoeModelNextN(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
+             logger.warning(
+                 "Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 model."
+             )
+             quant_config = None
+
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+             enable_tp=not global_server_args_dict["enable_dp_attention"],
+             prefix=add_prefix("embed_tokens", prefix),
+         )
+
+         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
+
+         self.decoder = Glm4MoeDecoderLayer(
+             config,
+             0,
+             quant_config=quant_config,
+             is_nextn=True,
+             prefix=add_prefix("decoder", prefix),
+         )
+
+         self.shared_head = nn.Module()
+         self.shared_head.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         zero_allocator = BumpAllocator(
+             buffer_size=2,
+             dtype=torch.float32,
+             device=(
+                 input_embeds.device if input_embeds is not None else input_ids.device
+             ),
+         )
+
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+
+         if hidden_states.shape[0] > 0:
+             hidden_states = self.eh_proj(
+                 torch.cat(
+                     (
+                         self.enorm(hidden_states),
+                         self.hnorm(forward_batch.spec_info.hidden_states),
+                     ),
+                     dim=-1,
+                 )
+             )
+
+         residual = None
+         with get_global_expert_distribution_recorder().disable_this_region():
+             hidden_states, residual = self.decoder(
+                 positions, hidden_states, forward_batch, residual, zero_allocator
+             )
+
+         if not forward_batch.forward_mode.is_idle():
+             if residual is not None:
+                 hidden_states, _ = self.shared_head.norm(hidden_states, residual)
+             else:
+                 hidden_states = self.shared_head.norm(hidden_states)
+
+         return hidden_states
+
+
+ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         nn.Module.__init__(self)
+         self.config = config
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.quant_config = quant_config
+         self.determine_num_fused_shared_experts("Glm4MoeForCausalLMNextN")
+
+         self.model = Glm4MoeModelNextN(
+             config, quant_config, prefix=add_prefix("model", prefix)
+         )
+         self.lm_head = ParallelLMHead(
+             config.vocab_size,
+             config.hidden_size,
+             quant_config=quant_config,
+             prefix=add_prefix("model.shared_head.head", prefix),
+             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+         )
+         self.logits_processor = LogitsProcessor(config)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, forward_batch)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head, forward_batch
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         super().load_weights(weights, is_nextn=True)
+
+
+ EntryClass = [Glm4MoeForCausalLMNextN]
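
The core of the new draft model above is the fusion step in Glm4MoeModelNextN.forward: the draft-token embeddings and the hidden states carried over from the target model (forward_batch.spec_info.hidden_states) are each RMS-normalized, concatenated along the feature dimension, and projected by eh_proj from 2 * hidden_size back to hidden_size before the single Glm4MoeDecoderLayer runs, mirroring the NextN/MTP draft heads sglang uses for DeepSeek-V3. Below is a minimal standalone sketch of just that fusion step, with hypothetical sizes and torch.nn.RMSNorm (PyTorch >= 2.4) standing in for sglang's fused RMSNorm:

import torch
from torch import nn

# Hypothetical sizes, for illustration only.
hidden_size, num_tokens = 4096, 8
enorm = nn.RMSNorm(hidden_size, eps=1e-6)    # normalizes the draft-token embeddings
hnorm = nn.RMSNorm(hidden_size, eps=1e-6)    # normalizes the target model's hidden states
eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)

embeds = torch.randn(num_tokens, hidden_size)       # stands in for embed_tokens(input_ids)
prev_hidden = torch.randn(num_tokens, hidden_size)  # stands in for spec_info.hidden_states

# Concatenate the two normalized streams and fold them back to hidden_size;
# in the real module this tensor is what the single decoder layer consumes.
fused = eh_proj(torch.cat((enorm(embeds), hnorm(prev_hidden)), dim=-1))
print(fused.shape)  # torch.Size([8, 4096])

Weight loading then reuses Glm4MoeForCausalLM.load_weights with is_nextn=True, and the EntryClass list registers Glm4MoeForCausalLMNextN so the model loader can resolve it by architecture name.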
sglang/srt/models/interns1.py (new file)
@@ -0,0 +1,328 @@
+ from typing import Iterable, List, Optional, Set, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import PretrainedConfig
+
+ from sglang.srt.distributed import parallel_state
+ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.managers.mm_utils import (
+     MultiModalityDataPaddingPatternTokenPairs,
+     general_mm_embed_routine,
+ )
+ from sglang.srt.managers.schedule_batch import (
+     Modality,
+     MultimodalDataItem,
+     MultimodalInputs,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.models.internvl import InternVisionModel
+ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+ from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
+ from sglang.utils import logger
+
+
+ class InternS1ForConditionalGeneration(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         use_flash_attn=True,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self._update_hf_config()
+         image_size = (
+             getattr(config, "force_image_size", None) or config.vision_config.image_size
+         )
+         patch_size = config.vision_config.patch_size
+         if isinstance(image_size, list):
+             image_size = image_size[0]
+         if isinstance(patch_size, list):
+             patch_size = patch_size[0]
+         self.patch_size = patch_size
+         self.select_layer = config.vision_feature_layer
+         self.num_image_token = int(
+             (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
+         )
+         self.downsample_ratio = config.downsample_ratio
+         self.ps_version = getattr(config, "ps_version", "v1")
+         # self.template = getattr(config, 'template', 'internvl2_5')
+
+         config.vision_config.use_flash_attn = True if use_flash_attn else False
+         config.text_config._attn_implementation = (
+             "flash_attention_2" if use_flash_attn else "eager"
+         )
+
+         logger.info(f"num_image_token: {self.num_image_token}")
+         logger.info(f"ps_version: {self.ps_version}")
+
+         self.vision_model = InternVisionModel(config.vision_config)
+         if config.text_config.architectures[0] == "Qwen2ForCausalLM":
+             self.language_model = Qwen2ForCausalLM(
+                 config=config.text_config, quant_config=quant_config
+             )
+         elif config.text_config.architectures[0] == "Qwen3MoeForCausalLM":
+             self.language_model = Qwen3MoeForCausalLM(
+                 config=config.text_config, quant_config=quant_config
+             )
+         else:
+             raise NotImplementedError(
+                 f"{config.text_config.architectures[0]} is not implemented."
+             )
+
+         vit_hidden_size = config.vision_config.hidden_size
+         llm_hidden_size = config.text_config.hidden_size
+
+         self.mlp1 = nn.Sequential(
+             nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
+             nn.Linear(
+                 vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size
+             ),
+             nn.GELU(),
+             nn.Linear(llm_hidden_size, llm_hidden_size),
+         )
+
+     def _update_hf_config(self):
+         """update hf config to support tp"""
+         world_size = parallel_state.get_tensor_model_parallel_world_size()
+         num_heads = self.config.vision_config.num_attention_heads
+         head_dim = self.config.vision_config.hidden_size // num_heads
+         num_dummy_heads = 0
+
+         if num_heads % world_size != 0:
+             num_dummy_heads = (
+                 (num_heads + world_size) // world_size
+             ) * world_size - num_heads
+
+         setattr(self.config.vision_config, "head_dim", head_dim)
+         setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
+
+     def pixel_shuffle(self, x, scale_factor=0.5):
+         n, w, h, c = x.size()
+         # N, W, H, C --> N, W, H * scale, C // scale
+         x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
+         # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
+         x = x.permute(0, 2, 1, 3).contiguous()
+         # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
+         x = x.view(
+             n,
+             int(h * scale_factor),
+             int(w * scale_factor),
+             int(c / (scale_factor * scale_factor)),
+         )
+         if self.ps_version == "v1":
+             logger.warn(
+                 "In ps_version 'v1', the height and width have not been swapped back, "
+                 "which results in a transposed image."
+             )
+         else:
+             x = x.permute(0, 2, 1, 3).contiguous()
+         return x
+
+     def extract_feature(self, pixel_values):
+         if self.select_layer == -1:
+             vit_embeds = self.vision_model(
+                 pixel_values=pixel_values, output_hidden_states=False, return_dict=True
+             ).last_hidden_state
+         else:
+             vit_embeds = self.vision_model(
+                 pixel_values=pixel_values, output_hidden_states=True, return_dict=True
+             ).hidden_states[self.select_layer]
+         vit_embeds = vit_embeds[:, 1:, :]
+
+         h = w = int(vit_embeds.shape[1] ** 0.5)
+         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+         vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
+         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
+         vit_embeds = self.mlp1(vit_embeds)
+         return vit_embeds
+
+     def get_image_feature(self, items: List[MultimodalDataItem]):
+         """
+         Projects the last hidden state from the vision model into language model space.
+
+         Returns:
+             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
+         """
+         pixel_values = torch.cat([item.feature for item in items])
+         image_features = self.extract_feature(pixel_values)
+         return image_features
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+
+         hs = general_mm_embed_routine(
+             input_ids=input_ids,
+             forward_batch=forward_batch,
+             language_model=self.language_model,
+             data_embedding_funcs={
+                 Modality.IMAGE: self.get_image_feature,
+             },
+             positions=positions,
+         )
+
+         return hs
+
+     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+         # Get all special token IDs
+         im_start_id: int = mm_inputs.im_start_id
+         im_end_id: int = mm_inputs.im_end_id
+
+         media_token_pairs = [(im_start_id, im_end_id)]
+         helper = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+
+         return helper.pad_input_tokens(input_ids, mm_inputs)
+
+     def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+         """pad attn qkv weights for dummy heads"""
+         num_dummy_heads = self.config.vision_config.num_dummy_heads
+         if num_dummy_heads == 0:
+             return loaded_weight
+         head_dim = self.config.vision_config.head_dim
+
+         if any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]):
+             if name.endswith(".weight"):
+                 dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]]
+             elif name.endswith(".bias"):
+                 dummy_shape = [num_dummy_heads, head_dim]
+             else:
+                 raise RuntimeError(f"Unsupported weight with name={name}")
+             padded_weight = loaded_weight.new_zeros(dummy_shape)
+             loaded_weight = torch.cat(
+                 [loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0
+             ).flatten(0, 1)
+         if "attn.proj.weight" in name:
+             padded_weight = loaded_weight.new_zeros(
+                 loaded_weight.shape[0], head_dim * num_dummy_heads
+             )
+             loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+         if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+             padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+             loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+         return loaded_weight
+
+     def _mapping_interns1_name(self, name):
+         names_map = {
+             "lm_head.weight": "language_model.lm_head.weight",
+             "model.multi_modal_projector.layer_norm.bias": "mlp1.0.bias",
+             "model.multi_modal_projector.layer_norm.weight": "mlp1.0.weight",
+             "model.multi_modal_projector.linear_1.bias": "mlp1.1.bias",
+             "model.multi_modal_projector.linear_1.weight": "mlp1.1.weight",
+             "model.multi_modal_projector.linear_2.bias": "mlp1.3.bias",
+             "model.multi_modal_projector.linear_2.weight": "mlp1.3.weight",
+             "model.vision_tower.embeddings.cls_token": "vision_model.embeddings.class_embedding",
+             "model.vision_tower.embeddings.patch_embeddings.projection.bias": "vision_model.embeddings.patch_embedding.bias",
+             "model.vision_tower.embeddings.patch_embeddings.projection.weight": "vision_model.embeddings.patch_embedding.weight",
+             "model.vision_tower.embeddings.position_embeddings": "vision_model.embeddings.position_embedding",
+         }
+         if name in names_map:
+             name = names_map[name]
+         elif name.startswith("model.language_model."):
+             name = "language_model.model." + name[len("model.language_model.") :]
+         elif name.startswith("model.vision_tower."):
+             name = "vision_model." + name[len("model.vision_tower.") :]
+
+         if name.startswith("vision_model.encoder.layer"):
+
+             name = name.replace(r".layer.", r".layers.")
+             name = name.replace(r".attention.", r".attn.attn.")
+             name = name.replace(r".projection_layer.", r".proj.")
+             name = name.replace(r".lambda_1", r".ls1")
+             name = name.replace(r".lambda_2", r".ls2")
+             name = name.replace(r".layernorm_before.", r".norm1.")
+             name = name.replace(r".layernorm_after.", r".norm2.")
+         return name
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         expert_params_mapping = []
+         if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
+             expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+                 ckpt_gate_proj_name="gate_proj",
+                 ckpt_down_proj_name="down_proj",
+                 ckpt_up_proj_name="up_proj",
+                 num_experts=self.config.num_experts,
+             )
+
+         params_dict = dict(self.named_parameters())
+         loaded_params: Set[str] = set()
+
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name:
+                 continue
+             name = self._mapping_interns1_name(name)
+             if "vision_model" in name:
+                 loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 # We have mlp.experts[0].gate_proj in the checkpoint.
+                 # Since we handle the experts below in expert_params_mapping,
+                 # we need to skip here BEFORE we update the name, otherwise
+                 # name will be updated to mlp.experts[0].gate_up_proj, which
+                 # will then be updated below in expert_params_mapping
+                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                 if "mlp.experts" in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 for mapping in expert_params_mapping:
+                     param_name, weight_name, expert_id, shard_id = mapping
+                     if weight_name not in name:
+                         continue
+                     name = name.replace(weight_name, param_name)
+                     param = params_dict[name]
+                     weight_loader = param.weight_loader
+                     weight_loader(
+                         param,
+                         loaded_weight,
+                         name,
+                         shard_id=shard_id,
+                         expert_id=expert_id,
+                     )
+                     break
+                 else:
+                     # Skip loading extra bias for GPTQ models.
+                     if name.endswith(".bias") and name not in params_dict:
+                         continue
+                     param = params_dict[name]
+                     weight_loader = getattr(
+                         param, "weight_loader", default_weight_loader
+                     )
+                     weight_loader(param, loaded_weight)
+
+             loaded_params.add(name)
+         unloaded_params = params_dict.keys() - loaded_params
+         if unloaded_params:
+             raise RuntimeError(
+                 f"Some weights are not initialized from checkpoints: {unloaded_params}"
+             )
+         return loaded_params
+
+
+ EntryClass = [InternS1ForConditionalGeneration]
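
pixel_shuffle above is the InternVL-style space-to-depth downsampling: with scale_factor=0.5 an (N, H, W, C) grid of vision tokens becomes (N, H/2, W/2, 4C), which is why the first LayerNorm and Linear in mlp1 are sized vit_hidden_size * int(1 / downsample_ratio) ** 2. A standalone sketch with hypothetical sizes, mirroring the ps_version "v2" branch (the one that swaps H and W back at the end):

import torch

def pixel_shuffle(x: torch.Tensor, scale_factor: float = 0.5) -> torch.Tensor:
    # Same reshape/permute sequence as InternS1ForConditionalGeneration.pixel_shuffle,
    # following the ps_version "v2" branch that restores the H/W order at the end.
    n, w, h, c = x.size()
    x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(
        n,
        int(h * scale_factor),
        int(w * scale_factor),
        int(c / (scale_factor * scale_factor)),
    )
    return x.permute(0, 2, 1, 3).contiguous()

# Hypothetical: 2 images, a 32x32 patch grid, ViT hidden size 1024.
vit_embeds = torch.randn(2, 32, 32, 1024)
out = pixel_shuffle(vit_embeds, scale_factor=0.5)
print(out.shape)  # torch.Size([2, 16, 16, 4096]): 4x fewer tokens, 4x wider features

extract_feature then flattens this back to (num_images, num_tokens, 4 * vit_hidden_size) and runs it through mlp1 to reach the language model's hidden size.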
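
_pad_vit_attn_dummy_heads handles vision towers whose head count does not divide the tensor-parallel world size: it appends zero-filled dummy heads so every rank receives the same number of head slices. A small sketch of the padding trick for a q/k/v projection weight, with hypothetical InternViT-like sizes (25 heads of dim 128, padded to 32 for tp=8):

import torch

# Hypothetical sizes: 25 attention heads of dim 128 (hidden 3200), 7 dummy heads for tp=8.
num_heads, num_dummy_heads, head_dim, in_dim = 25, 7, 128, 3200
w = torch.randn(num_heads * head_dim, in_dim)  # e.g. an attn.q_proj ".weight"

# Same idea as _pad_vit_attn_dummy_heads: view the output rows as (heads, head_dim, in_dim),
# append zero rows for the dummy heads, then flatten back to a 2-D weight.
padded = torch.cat(
    [w.unflatten(0, (-1, head_dim)), w.new_zeros(num_dummy_heads, head_dim, in_dim)],
    dim=0,
).flatten(0, 1)
print(padded.shape)  # torch.Size([4096, 3200]) == ((25 + 7) * 128, 3200)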