sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
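
Several modules move in this release: the EPLB code leaves `sglang.srt.managers` / `sglang.srt.model_executor` for a new `sglang.srt.eplb` package, and the multimodal processors move from `sglang.srt.managers.multimodal_processors` to `sglang.srt.multimodal.processors`. A minimal sketch of how downstream imports would change, assuming the module paths follow the new file locations shown above and that no compatibility re-exports remain at the old paths:

```python
# 0.4.8.post1 layout:
# import sglang.srt.managers.expert_distribution
# import sglang.srt.managers.multimodal_processors.base_processor

# 0.4.9.post1 layout (inferred from the file moves listed above):
import sglang.srt.eplb.expert_distribution
import sglang.srt.multimodal.processors.base_processor
```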
sglang/srt/models/mllama4.py
CHANGED
@@ -1,3 +1,6 @@
+import json as json_lib
+import logging
+import os
 from collections.abc import Iterable
 from typing import List, Optional, Set, Tuple
 
@@ -16,8 +19,17 @@ from sglang.srt.managers.mm_utils import (
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, is_cpu
+
+_is_cpu = is_cpu()
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.utils import add_prefix
 
+logger = logging.getLogger(__name__)
+
 
 class Llama4ForConditionalGeneration(nn.Module):
     packed_modules_mapping = {
@@ -35,31 +47,98 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.config = config
         self.quant_config = quant_config
 
-
-        self.
+        # Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
+        self.has_vision = self._has_vision_weights(config)
+        if not self.has_vision:
+            logger.warning(
+                "No vision weights found in checkpoint. Model will run in text-only mode. "
+                "Multimodal capabilities (image processing) will be unavailable."
+            )
+
+        if self.has_vision:
+            self.vision_model = Llama4VisionModel(config.vision_config)
+            self.multi_modal_projector = Llama4MultiModalProjector(config)
+        else:
+            self.vision_model = None
+            self.multi_modal_projector = None
 
         # Initialize the language model
         from sglang.srt.models.llama4 import Llama4ForCausalLM
 
         self.language_model = Llama4ForCausalLM(
-            config.text_config,
+            config.text_config if hasattr(config, "text_config") else config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
 
-        self.logits_processor = LogitsProcessor(
+        self.logits_processor = LogitsProcessor(
+            config.text_config if hasattr(config, "text_config") else config
+        )
 
-    def 
-
-
+    def _has_vision_weights(self, config) -> bool:
+        """Check if the model has vision components by examining the checkpoint."""
+        model_path = getattr(config, "_name_or_path", None)
+        if not model_path:
+            return False
+
+        # Check if this is a local path first
+        if os.path.isdir(model_path):
+            index_file = os.path.join(model_path, "model.safetensors.index.json")
+            if os.path.exists(index_file):
+                return self._check_vision_weights_in_index(index_file)
+
+        # For HuggingFace models, we need to check the actual checkpoint
+        # The config might say it's multimodal, but the checkpoint might be text-only
+        try:
+            # Try to access the HuggingFace cache directory
+            from huggingface_hub import try_to_load_from_cache
+
+            # Check if index file exists in cache
+            index_file_path = try_to_load_from_cache(
+                repo_id=model_path,
+                filename="model.safetensors.index.json",
+                cache_dir=None,
+            )
+
+            if index_file_path and os.path.exists(index_file_path):
+                return self._check_vision_weights_in_index(index_file_path)
+
+        except Exception:
+            # If we can't access the cache, fall back to config-based detection
+            pass
+
+        # Fallback, assume text-only
+        return False
+
+    def _check_vision_weights_in_index(self, index_file: str) -> bool:
+        """Check if the model.safetensors.index.json contains vision weights."""
+        try:
+            with open(index_file, "r") as f:
+                index_data = json_lib.load(f)
+
+            vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"]
+            weight_names = index_data.get("weight_map", {}).keys()
+
+            return any(
+                pattern in weight_name
+                for weight_name in weight_names
+                for pattern in vision_patterns
+            )
+        except (OSError, json_lib.JSONDecodeError, KeyError):
+            return False
 
-
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_image_feature(
         self,
         items: List[MultimodalDataItem],
     ) -> torch.Tensor:
+        # For text-only models, return None or raise an error
+        if not self.has_vision or self.vision_model is None:
+            raise ValueError("Vision model not available for text-only checkpoint")
+
         pixel_values = (
             torch.concat([item.pixel_values for item in items])
             .to(next(self.vision_model.parameters()).device)
@@ -80,11 +159,14 @@ class Llama4ForConditionalGeneration(nn.Module):
         **kwargs: object,
     ) -> torch.Tensor:
 
+        # For text-only models, pass None for image_data_embedding_func
+        image_embedding_func = self.get_image_feature if self.has_vision else None
+
         hs = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=
+            image_data_embedding_func=image_embedding_func,
             positions=positions,
         )
 
@@ -110,18 +192,21 @@ class Llama4ForConditionalGeneration(nn.Module):
 
         # rotary embeds should be sliced
         if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight":
-
-
-
+            if _is_cpu:
+                dim = self.language_model.config.original_total_num_kv_heads
+            else:
+                dim = self.language_model.config.num_key_value_heads
+            loaded_weight = permute(loaded_weight, dim)
         elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight":
-
-
-
+            if _is_cpu:
+                dim = self.language_model.config.original_num_attention_heads
+            else:
+                dim = self.language_model.config.num_attention_heads
+            loaded_weight = permute(loaded_weight, dim)
 
         return name, loaded_weight
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
-
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
@@ -134,11 +219,12 @@ class Llama4ForConditionalGeneration(nn.Module):
         ]
 
         params_dict = dict(self.named_parameters())
+        num_experts = (
+            self.config.text_config.num_local_experts
+            if hasattr(self.config, "text_config")
+            else self.config.num_local_experts
+        )
 
-        num_experts = self.config.text_config.num_local_experts
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
@@ -147,81 +233,308 @@ class Llama4ForConditionalGeneration(nn.Module):
         )
 
         for name, loaded_weight in weights:
-            if 
+            if self._should_skip_weight(name):
+                continue
+
+            name = self._transform_weight_name(name)
+
+            if "vision" not in name:
                 name, loaded_weight = self.permute_qk_weight_for_rotary(
                     name, loaded_weight
                 )
 
-
-
-
-
-
-
-
-
-
-
-
+            if self._handle_scale_remapping(name, params_dict):
+                continue
+
+            if self._handle_stacked_params(
+                name, loaded_weight, stacked_params_mapping, params_dict
+            ):
+                continue
+
+            if self._handle_expert_weights(
+                name, loaded_weight, expert_params_mapping, params_dict, num_experts
+            ):
+                continue
+
+            self._handle_default_weight(name, loaded_weight, params_dict)
+
+    def _should_skip_weight(self, name: str) -> bool:
+        """Check if we should skip loading this weight."""
+        return "vision" in name and not self.has_vision
+
+    def _transform_weight_name(self, name: str) -> str:
+        """Transform weight name by adding language_model prefix if needed."""
+        if (
+            not name.startswith("language_model.")
+            and "vision" not in name
+            and "multi_modal_projector" not in name
+        ):
+            return f"language_model.{name}"
+        return name
+
+    def _handle_scale_remapping(self, name: str, params_dict: dict) -> bool:
+        """Handle scale parameter remapping. Returns True if handled."""
+        if "scale" in name and "expert" not in name:
+            remapped_name = maybe_remap_kv_scale_name(name, params_dict)
+            return remapped_name is None
+        return False
+
+    def _handle_stacked_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        stacked_params_mapping: list,
+        params_dict: dict,
+    ) -> bool:
+        """Handle stacked parameter loading. Returns True if handled."""
+        for param_name, weight_name, shard_id in stacked_params_mapping:
+            if weight_name in name and "vision" not in name:
+                transformed_name = name.replace(weight_name, param_name)
+                param = params_dict[transformed_name]
+                param.weight_loader(param, loaded_weight, shard_id)
+                return True
+        return False
+
+    def _handle_expert_weights(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        expert_params_mapping: list,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle expert weight loading for MoE (Mixture of Experts) layers.
+
+        Args:
+            name: Parameter name from the checkpoint
+            loaded_weight: The weight tensor to be loaded
+            expert_params_mapping: Mapping of parameter names to expert configurations
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts in the MoE layer
+
+        Returns:
+            bool: True if the parameter was handled (is an expert parameter), False otherwise
+        """
+        if ".experts" not in name:
+            return False
+
+        if "experts.gate_up_proj" not in name and "experts.down_proj" not in name:
+            return self._handle_other_expert_params(
+                name, loaded_weight, expert_params_mapping, params_dict
+            )
+
+        if "scale" in name:
+            return self._handle_expert_scale_params(
+                name, loaded_weight, params_dict, num_experts
+            )
+        else:
+            return self._handle_expert_weight_params(
+                name, loaded_weight, params_dict, num_experts
+            )
+
+    def _handle_other_expert_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        expert_params_mapping: list,
+        params_dict: dict,
+    ) -> bool:
+        """Handle expert parameters that are not gate_up_proj or down_proj weights.
+
+        Args:
+            name: Parameter name from the checkpoint
+            loaded_weight: The weight tensor to be loaded
+            expert_params_mapping: List of tuples mapping checkpoint names to model parameters
+            params_dict: Dictionary of model parameters
+
+        Returns:
+            bool: True if parameter was found and handled, False otherwise
+        """
+        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
+            if weight_name in name:
+                transformed_name = name.replace(weight_name, param_name)
+                param = params_dict[transformed_name]
+                param.weight_loader(
+                    param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
+                )
+                return True
+        return False
+
+    def _transform_expert_name(
+        self, name: str, is_weight: bool = False
+    ) -> Tuple[str, str, List[str]]:
+        """Transform expert parameter name and get shard information.
+
+        Args:
+            name: The original parameter name
+            is_weight: Whether this is a weight parameter (adds _weight suffix)
+
+        Returns:
+            Tuple of (transformed_name, shard_id, shard_id_list)
+        """
+        suffix = "_weight" if is_weight else ""
+
+        if ".gate_up_proj" in name:
+            transformed_name = name.replace(
+                ".experts.gate_up_proj", f".experts.w13{suffix}"
+            )
+            shard_id = "w13"
+            shard_id_list = ["w1", "w3"]
+        else:  # down_proj
+            transformed_name = name.replace(
+                ".experts.down_proj", f".experts.w2{suffix}"
+            )
+            shard_id = "w2"
+            shard_id_list = ["w2"]
+
+        return transformed_name, shard_id, shard_id_list
+
+    def _handle_expert_scale_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle quantization scale parameters for expert weights.
+
+        Args:
+            name: Parameter name containing scale information
+            loaded_weight: Scale tensor to be loaded
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts for broadcast operations
+
+        Returns:
+            bool: True (always handles scale parameters)
+        """
+        import re
+
+        # Check if this matches the expert parameter pattern: experts.{expert_id}.{param_name}
+        expert_match = re.search(r"experts\.(\d+)\.", name)
+
+        # Transform name
+        transformed_name, _, _ = self._transform_expert_name(name)
+
+        if transformed_name not in params_dict:
+            return True
+
+        param = params_dict[transformed_name]
+
+        # Handle scale parameters
+        if expert_match:
+            # If we have a specific expert ID, only load for that expert
+            expert_id = int(expert_match.group(1))
+            # For scale parameters, we can directly set the value
+            param.data[expert_id] = loaded_weight
+        else:
+            # No expert ID found - this is a single scale for all experts
+            # Load the same scale for all experts
+            for expert_id in range(num_experts):
+                param.data[expert_id] = loaded_weight
+
+        return True
+
+    def _handle_expert_weight_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle actual weight tensors for expert layers (gate_up_proj and down_proj).
+
+        Args:
+            name: Parameter name (should contain gate_up_proj or down_proj)
+            loaded_weight: Weight tensor(s) to be loaded
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts for tensor distribution
+
+        Returns:
+            bool: True (always handles weight parameters)
+        """
+        # Transform name and get shard info
+        transformed_name, _, shard_id_list = self._transform_expert_name(
+            name, is_weight=True
+        )
+
+        if ".gate_up_proj" in name:
+            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
+        else:  # down_proj
+            loaded_weight_list = [loaded_weight]
+
+        for param_name, weight_chunk, shard_id in zip(
+            [transformed_name] * len(shard_id_list), loaded_weight_list, shard_id_list
+        ):
+            if param_name not in params_dict:
+                continue
+
+            param = params_dict[param_name]
+            weight_loader = param.weight_loader
+
+            # Handle the case where loaded_weight might be a single tensor for all experts
+            if weight_chunk.dim() == 2:
+                # Single tensor case - load for all experts
+                for expert_id in range(num_experts):
+                    weight_loader(
+                        param,
+                        weight_chunk.T,
+                        param_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
             else:
-
-
-
-
-
-
-
-
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    if ".gate_up_proj" in name:
-                        name_list = [
-                            name.replace(
-                                ".experts.gate_up_proj", ".experts.w13_weight"
-                            )
-                        ] * 2
-                        loaded_weight_list = loaded_weight.chunk(2, dim=-1)
-                        shard_id_list = ["w1", "w3"]
-                    else:
-                        name_list = [
-                            name.replace(".experts.down_proj", ".experts.w2_weight")
-                        ]
-                        shard_id_list = ["w2"]
-                        loaded_weight_list = [loaded_weight]
-                    for name, loaded_weight, shard_id in zip(
-                        name_list, loaded_weight_list, shard_id_list
-                    ):
-                        param = params_dict[name]
-                        weight_loader = param.weight_loader
-                        for expert_id in range(num_experts):
-                            weight_loader(
-                                param,
-                                loaded_weight[expert_id].T,
-                                name,
-                                shard_id=shard_id,
-                                expert_id=expert_id,
-                            )
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(
-                    param, "weight_loader", default_weight_loader
+                # Multiple experts case - load each expert's weights
+                for expert_id in range(num_experts):
+                    weight_loader(
+                        param,
+                        weight_chunk[expert_id].T,
+                        param_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
                     )
-
+
+        return True
+
+    def _handle_default_weight(
+        self, name: str, loaded_weight: torch.Tensor, params_dict: dict
+    ):
+        """Handle default weight loading."""
+        # Skip loading extra bias for GPTQ models
+        if name.endswith(".bias") and name not in params_dict:
+            return
+
+        param = params_dict[name]
+        weight_loader = getattr(param, "weight_loader", default_weight_loader)
+        weight_loader(param, loaded_weight)
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
+            self.language_model.set_eagle3_layers_to_capture(layer_ids)
+
+    def get_embed_and_head(self):
+        # For EAGLE3, we delegate to the language model which should have this method
+        # If the language model doesn't have lm_head (like EAGLE3), we return None for head
+        embed = self.language_model.get_embed()
+        if hasattr(self.language_model, "get_embed_and_head"):
+            return self.language_model.get_embed_and_head()
+        elif hasattr(self.language_model, "lm_head"):
+            return embed, self.language_model.lm_head.weight
+        else:
+            # For EAGLE3, head might not be needed
+            return embed, None
+
+    def set_embed_and_head(self, embed, head):
+        if hasattr(self.language_model, "set_embed_and_head"):
+            return self.language_model.set_embed_and_head(embed, head)
+        else:
+            # For EAGLE3, only set embed
+            return self.language_model.set_embed(embed)
+
+    def get_embed(self):
+        return self.language_model.get_embed()
+
+    def set_embed(self, embed):
+        return self.language_model.set_embed(embed)
 
 
 EntryClass = Llama4ForConditionalGeneration
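
The text-only detection added above keys off the checkpoint's safetensors index rather than the config. A standalone sketch of the same check, handy for verifying ahead of time whether a local Llama 4 checkpoint will load in text-only mode; the function name and example path are illustrative, not part of the sglang API, and the logic mirrors `_check_vision_weights_in_index` from the diff:

```python
import json
import os


def checkpoint_has_vision_weights(checkpoint_dir: str) -> bool:
    """Return True if the safetensors index lists any vision-related weights.

    Mirrors the patterns used by Llama4ForConditionalGeneration in sglang 0.4.9.post1.
    """
    index_file = os.path.join(checkpoint_dir, "model.safetensors.index.json")
    if not os.path.exists(index_file):
        # Same fallback as the model code: without an index, assume text-only.
        return False
    with open(index_file) as f:
        weight_map = json.load(f).get("weight_map", {})
    vision_patterns = ("vision_model", "vision_tower", "multi_modal_projector")
    return any(p in name for name in weight_map for p in vision_patterns)


# Example with a hypothetical local path:
# checkpoint_has_vision_weights("/models/llama4-fp8-text-only")  # -> False expected
```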
sglang/srt/models/phi4mm.py
CHANGED
@@ -446,9 +446,7 @@ class Phi4MMForCausalLM(nn.Module):
         return hidden_states
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-
-        im_token_id: int = mm_inputs.im_token_id
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def should_apply_lora(self, module_name: str) -> bool:
sglang/srt/models/pixtral.py
CHANGED
@@ -268,15 +268,14 @@ class PixtralHFVisionModel(nn.Module):
 
     DEFAULT_IMAGE_TOKEN_ID = 10
 
-    def pad_input_ids(self, input_ids: List[int],
-        return self.input_padder.pad_input_tokens(input_ids,
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        return self.input_padder.pad_input_tokens(input_ids, mm_inputs)
 
     def __init__(
         self,
         config: PixtralVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         *,
-        image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
         num_hidden_layers_override: Optional[int] = None,
         prefix: str = "",
     ) -> None:
@@ -314,11 +313,8 @@ class PixtralHFVisionModel(nn.Module):
         )
 
         # Initialize patch position embedding
-        self.image_token_id = image_token_id
         self.patch_positional_embedding = PixtralRotaryEmbedding(config)
-        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
-            [self.image_token_id]
-        )
+        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens()
 
     @property
    def dtype(self):
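
The mllama4, phi4mm, and pixtral diffs all converge on the same padding API: `MultiModalityDataPaddingPatternMultimodalTokens` is now constructed without an explicit token-id list, and the token ids are presumably derived from `mm_inputs` when `pad_input_tokens` is called. A sketch of the new call shape shared by these models (standalone function for illustration; in the package it is a model method, and the import path is taken from the mllama4 hunk header above):

```python
from typing import List

from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
from sglang.srt.managers.schedule_batch import MultimodalInputs


def pad_input_ids(input_ids: List[int], mm_inputs: MultimodalInputs):
    # 0.4.8.post1 passed a per-model [im_token_id] list to the constructor;
    # 0.4.9.post1 builds the pattern with no arguments and lets it work from mm_inputs.
    pattern = MultiModalityDataPaddingPatternMultimodalTokens()
    return pattern.pad_input_tokens(input_ids, mm_inputs)
```

Custom multimodal model integrations that still pass a token-id list would need the same one-line update when moving to 0.4.9.post1.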