sglang 0.4.9__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their public registries.
- sglang/bench_serving.py +2 -2
- sglang/srt/configs/model_config.py +12 -1
- sglang/srt/conversation.py +35 -1
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/layers/communicator.py +3 -1
- sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
- sglang/srt/layers/layernorm.py +2 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +58 -0
- sglang/srt/layers/moe/ep_moe/layer.py +140 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +135 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +28 -7
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/vocab_parallel_embedding.py +9 -3
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/io_struct.py +8 -1
- sglang/srt/managers/mm_utils.py +4 -2
- sglang/srt/managers/schedule_batch.py +1 -1
- sglang/srt/managers/scheduler.py +17 -5
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +113 -63
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/models/deepseek_v2.py +16 -2
- sglang/srt/models/mllama4.py +360 -79
- sglang/srt/multimodal/mm_utils.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +62 -60
- sglang/srt/server_args.py +15 -0
- sglang/srt/two_batch_overlap.py +3 -0
- sglang/srt/utils.py +37 -17
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +4 -3
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +47 -43
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/mllama4.py
CHANGED
@@ -1,3 +1,6 @@
+import json as json_lib
+import logging
+import os
 from collections.abc import Iterable
 from typing import List, Optional, Set, Tuple

@@ -19,6 +22,13 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cpu

 _is_cpu = is_cpu()
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)


 class Llama4ForConditionalGeneration(nn.Module):
@@ -37,19 +47,85 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.config = config
         self.quant_config = quant_config

-
-        self.
+        # Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
+        self.has_vision = self._has_vision_weights(config)
+        if not self.has_vision:
+            logger.warning(
+                "No vision weights found in checkpoint. Model will run in text-only mode. "
+                "Multimodal capabilities (image processing) will be unavailable."
+            )
+
+        if self.has_vision:
+            self.vision_model = Llama4VisionModel(config.vision_config)
+            self.multi_modal_projector = Llama4MultiModalProjector(config)
+        else:
+            self.vision_model = None
+            self.multi_modal_projector = None

         # Initialize the language model
         from sglang.srt.models.llama4 import Llama4ForCausalLM

         self.language_model = Llama4ForCausalLM(
-            config.text_config,
+            config.text_config if hasattr(config, "text_config") else config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )

-        self.logits_processor = LogitsProcessor(
+        self.logits_processor = LogitsProcessor(
+            config.text_config if hasattr(config, "text_config") else config
+        )
+
+    def _has_vision_weights(self, config) -> bool:
+        """Check if the model has vision components by examining the checkpoint."""
+        model_path = getattr(config, "_name_or_path", None)
+        if not model_path:
+            return False
+
+        # Check if this is a local path first
+        if os.path.isdir(model_path):
+            index_file = os.path.join(model_path, "model.safetensors.index.json")
+            if os.path.exists(index_file):
+                return self._check_vision_weights_in_index(index_file)
+
+        # For HuggingFace models, we need to check the actual checkpoint
+        # The config might say it's multimodal, but the checkpoint might be text-only
+        try:
+            # Try to access the HuggingFace cache directory
+            from huggingface_hub import try_to_load_from_cache
+
+            # Check if index file exists in cache
+            index_file_path = try_to_load_from_cache(
+                repo_id=model_path,
+                filename="model.safetensors.index.json",
+                cache_dir=None,
+            )
+
+            if index_file_path and os.path.exists(index_file_path):
+                return self._check_vision_weights_in_index(index_file_path)
+
+        except Exception:
+            # If we can't access the cache, fall back to config-based detection
+            pass
+
+        # Fallback, assume text-only
+        return False
+
+    def _check_vision_weights_in_index(self, index_file: str) -> bool:
+        """Check if the model.safetensors.index.json contains vision weights."""
+        try:
+            with open(index_file, "r") as f:
+                index_data = json_lib.load(f)
+
+            vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"]
+            weight_names = index_data.get("weight_map", {}).keys()
+
+            return any(
+                pattern in weight_name
+                for weight_name in weight_names
+                for pattern in vision_patterns
+            )
+        except (OSError, json_lib.JSONDecodeError, KeyError):
+            return False

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
@@ -59,6 +135,10 @@ class Llama4ForConditionalGeneration(nn.Module):
         self,
         items: List[MultimodalDataItem],
     ) -> torch.Tensor:
+        # For text-only models, return None or raise an error
+        if not self.has_vision or self.vision_model is None:
+            raise ValueError("Vision model not available for text-only checkpoint")
+
         pixel_values = (
             torch.concat([item.pixel_values for item in items])
             .to(next(self.vision_model.parameters()).device)
@@ -79,11 +159,14 @@ class Llama4ForConditionalGeneration(nn.Module):
         **kwargs: object,
     ) -> torch.Tensor:

+        # For text-only models, pass None for image_data_embedding_func
+        image_embedding_func = self.get_image_feature if self.has_vision else None
+
         hs = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=
+            image_data_embedding_func=image_embedding_func,
             positions=positions,
         )

@@ -124,7 +207,6 @@ class Llama4ForConditionalGeneration(nn.Module):
         return name, loaded_weight

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
-
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
@@ -137,11 +219,12 @@ class Llama4ForConditionalGeneration(nn.Module):
         ]

         params_dict = dict(self.named_parameters())
+        num_experts = (
+            self.config.text_config.num_local_experts
+            if hasattr(self.config, "text_config")
+            else self.config.num_local_experts
+        )

-        num_experts = self.config.text_config.num_local_experts
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
@@ -150,81 +233,279 @@ class Llama4ForConditionalGeneration(nn.Module):
         )

         for name, loaded_weight in weights:
-            if
+            if self._should_skip_weight(name):
+                continue
+
+            name = self._transform_weight_name(name)
+
+            if "vision" not in name:
                 name, loaded_weight = self.permute_qk_weight_for_rotary(
                     name, loaded_weight
                 )

-
-
-
-
-
-
-
-
-
-
-
+            if self._handle_scale_remapping(name, params_dict):
+                continue
+
+            if self._handle_stacked_params(
+                name, loaded_weight, stacked_params_mapping, params_dict
+            ):
+                continue
+
+            if self._handle_expert_weights(
+                name, loaded_weight, expert_params_mapping, params_dict, num_experts
+            ):
+                continue
+
+            self._handle_default_weight(name, loaded_weight, params_dict)
+
+    def _should_skip_weight(self, name: str) -> bool:
+        """Check if we should skip loading this weight."""
+        return "vision" in name and not self.has_vision
+
+    def _transform_weight_name(self, name: str) -> str:
+        """Transform weight name by adding language_model prefix if needed."""
+        if (
+            not name.startswith("language_model.")
+            and "vision" not in name
+            and "multi_modal_projector" not in name
+        ):
+            return f"language_model.{name}"
+        return name
+
+    def _handle_scale_remapping(self, name: str, params_dict: dict) -> bool:
+        """Handle scale parameter remapping. Returns True if handled."""
+        if "scale" in name and "expert" not in name:
+            remapped_name = maybe_remap_kv_scale_name(name, params_dict)
+            return remapped_name is None
+        return False
+
+    def _handle_stacked_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        stacked_params_mapping: list,
+        params_dict: dict,
+    ) -> bool:
+        """Handle stacked parameter loading. Returns True if handled."""
+        for param_name, weight_name, shard_id in stacked_params_mapping:
+            if weight_name in name and "vision" not in name:
+                transformed_name = name.replace(weight_name, param_name)
+                param = params_dict[transformed_name]
+                param.weight_loader(param, loaded_weight, shard_id)
+                return True
+        return False
+
+    def _handle_expert_weights(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        expert_params_mapping: list,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle expert weight loading for MoE (Mixture of Experts) layers.
+
+        Args:
+            name: Parameter name from the checkpoint
+            loaded_weight: The weight tensor to be loaded
+            expert_params_mapping: Mapping of parameter names to expert configurations
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts in the MoE layer
+
+        Returns:
+            bool: True if the parameter was handled (is an expert parameter), False otherwise
+        """
+        if ".experts" not in name:
+            return False
+
+        if "experts.gate_up_proj" not in name and "experts.down_proj" not in name:
+            return self._handle_other_expert_params(
+                name, loaded_weight, expert_params_mapping, params_dict
+            )
+
+        if "scale" in name:
+            return self._handle_expert_scale_params(
+                name, loaded_weight, params_dict, num_experts
+            )
+        else:
+            return self._handle_expert_weight_params(
+                name, loaded_weight, params_dict, num_experts
+            )
+
+    def _handle_other_expert_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        expert_params_mapping: list,
+        params_dict: dict,
+    ) -> bool:
+        """Handle expert parameters that are not gate_up_proj or down_proj weights.
+
+        Args:
+            name: Parameter name from the checkpoint
+            loaded_weight: The weight tensor to be loaded
+            expert_params_mapping: List of tuples mapping checkpoint names to model parameters
+            params_dict: Dictionary of model parameters
+
+        Returns:
+            bool: True if parameter was found and handled, False otherwise
+        """
+        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
+            if weight_name in name:
+                transformed_name = name.replace(weight_name, param_name)
+                param = params_dict[transformed_name]
+                param.weight_loader(
+                    param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
+                )
+                return True
+        return False
+
+    def _transform_expert_name(
+        self, name: str, is_weight: bool = False
+    ) -> Tuple[str, str, List[str]]:
+        """Transform expert parameter name and get shard information.
+
+        Args:
+            name: The original parameter name
+            is_weight: Whether this is a weight parameter (adds _weight suffix)
+
+        Returns:
+            Tuple of (transformed_name, shard_id, shard_id_list)
+        """
+        suffix = "_weight" if is_weight else ""
+
+        if ".gate_up_proj" in name:
+            transformed_name = name.replace(
+                ".experts.gate_up_proj", f".experts.w13{suffix}"
+            )
+            shard_id = "w13"
+            shard_id_list = ["w1", "w3"]
+        else:  # down_proj
+            transformed_name = name.replace(
+                ".experts.down_proj", f".experts.w2{suffix}"
+            )
+            shard_id = "w2"
+            shard_id_list = ["w2"]
+
+        return transformed_name, shard_id, shard_id_list
+
+    def _handle_expert_scale_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle quantization scale parameters for expert weights.
+
+        Args:
+            name: Parameter name containing scale information
+            loaded_weight: Scale tensor to be loaded
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts for broadcast operations
+
+        Returns:
+            bool: True (always handles scale parameters)
+        """
+        import re
+
+        # Check if this matches the expert parameter pattern: experts.{expert_id}.{param_name}
+        expert_match = re.search(r"experts\.(\d+)\.", name)
+
+        # Transform name
+        transformed_name, _, _ = self._transform_expert_name(name)
+
+        if transformed_name not in params_dict:
+            return True
+
+        param = params_dict[transformed_name]
+
+        # Handle scale parameters
+        if expert_match:
+            # If we have a specific expert ID, only load for that expert
+            expert_id = int(expert_match.group(1))
+            # For scale parameters, we can directly set the value
+            param.data[expert_id] = loaded_weight
+        else:
+            # No expert ID found - this is a single scale for all experts
+            # Load the same scale for all experts
+            for expert_id in range(num_experts):
+                param.data[expert_id] = loaded_weight
+
+        return True
+
+    def _handle_expert_weight_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle actual weight tensors for expert layers (gate_up_proj and down_proj).
+
+        Args:
+            name: Parameter name (should contain gate_up_proj or down_proj)
+            loaded_weight: Weight tensor(s) to be loaded
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts for tensor distribution
+
+        Returns:
+            bool: True (always handles weight parameters)
+        """
+        # Transform name and get shard info
+        transformed_name, _, shard_id_list = self._transform_expert_name(
+            name, is_weight=True
+        )
+
+        if ".gate_up_proj" in name:
+            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
+        else:  # down_proj
+            loaded_weight_list = [loaded_weight]
+
+        for param_name, weight_chunk, shard_id in zip(
+            [transformed_name] * len(shard_id_list), loaded_weight_list, shard_id_list
+        ):
+            if param_name not in params_dict:
+                continue
+
+            param = params_dict[param_name]
+            weight_loader = param.weight_loader
+
+            # Handle the case where loaded_weight might be a single tensor for all experts
+            if weight_chunk.dim() == 2:
+                # Single tensor case - load for all experts
+                for expert_id in range(num_experts):
+                    weight_loader(
+                        param,
+                        weight_chunk.T,
+                        param_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
             else:
-
-
-
-
-
-
-
-
-                        if weight_name not in name:
-                            continue
-                        name = name.replace(weight_name, param_name)
-                        param = params_dict[name]
-                        weight_loader = param.weight_loader
-                        weight_loader(
-                            param,
-                            loaded_weight,
-                            name,
-                            shard_id=shard_id,
-                            expert_id=expert_id,
-                        )
-                        break
-                    else:
-                        if ".gate_up_proj" in name:
-                            name_list = [
-                                name.replace(
-                                    ".experts.gate_up_proj", ".experts.w13_weight"
-                                )
-                            ] * 2
-                            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
-                            shard_id_list = ["w1", "w3"]
-                        else:
-                            name_list = [
-                                name.replace(".experts.down_proj", ".experts.w2_weight")
-                            ]
-                            shard_id_list = ["w2"]
-                            loaded_weight_list = [loaded_weight]
-                        for name, loaded_weight, shard_id in zip(
-                            name_list, loaded_weight_list, shard_id_list
-                        ):
-                            param = params_dict[name]
-                            weight_loader = param.weight_loader
-                            for expert_id in range(num_experts):
-                                weight_loader(
-                                    param,
-                                    loaded_weight[expert_id].T,
-                                    name,
-                                    shard_id=shard_id,
-                                    expert_id=expert_id,
-                                )
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
+                # Multiple experts case - load each expert's weights
+                for expert_id in range(num_experts):
+                    weight_loader(
+                        param,
+                        weight_chunk[expert_id].T,
+                        param_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
                     )
-
+
+        return True
+
+    def _handle_default_weight(
+        self, name: str, loaded_weight: torch.Tensor, params_dict: dict
+    ):
+        """Handle default weight loading."""
+        # Skip loading extra bias for GPTQ models
+        if name.endswith(".bias") and name not in params_dict:
+            return
+
+        param = params_dict[name]
+        weight_loader = getattr(param, "weight_loader", default_weight_loader)
+        weight_loader(param, loaded_weight)

     def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
         if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
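The text-only detection added above boils down to scanning the checkpoint's safetensors weight index for vision-related parameter names. A standalone sketch of that check for a locally downloaded checkpoint directory; the helper name and the assumption that the index file sits next to the weights mirror the diff and are not a public sglang API:

import json
import os


def checkpoint_has_vision_weights(model_dir: str) -> bool:
    """Return True if the safetensors index lists any vision weights."""
    index_file = os.path.join(model_dir, "model.safetensors.index.json")
    if not os.path.exists(index_file):
        # Same fallback as the diff: no index file, assume text-only.
        return False
    with open(index_file) as f:
        weight_map = json.load(f).get("weight_map", {})
    patterns = ("vision_model", "vision_tower", "multi_modal_projector")
    return any(p in name for name in weight_map for p in patterns)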

sglang/srt/multimodal/mm_utils.py
CHANGED
@@ -28,12 +28,12 @@ LLaVA-Onevision : https://arxiv.org/pdf/2408.03326

 """
 import ast
-import base64
 import math
 import re
 from io import BytesIO

 import numpy as np
+import pybase64
 from PIL import Image

 from sglang.srt.utils import flatten_nested_list
@@ -252,7 +252,7 @@ def process_anyres_image(image, processor, grid_pinpoints):


 def load_image_from_base64(image):
-    return Image.open(BytesIO(
+    return Image.open(BytesIO(pybase64.b64decode(image, validate=True)))


 def expand2square(pil_img, background_color):
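The only change in this file swaps the standard-library base64 decoder for pybase64 with strict validation. A minimal self-contained sketch of the resulting helper, assuming pybase64 and Pillow are installed; it mirrors the diff rather than defining any new sglang API:

from io import BytesIO

import pybase64
from PIL import Image


def load_image_from_base64(image):
    # validate=True rejects malformed base64 instead of silently ignoring stray bytes.
    return Image.open(BytesIO(pybase64.b64decode(image, validate=True)))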

sglang/srt/multimodal/processors/mllama4.py
CHANGED
@@ -60,70 +60,72 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         )

         # Handle image resolutions and aspect ratios
-        if "pixel_values" in processor_output:
-
-            tokenizer = self._processor.tokenizer
+        if "pixel_values" not in processor_output:  # no image processed
+            return None

-
-
-            max_num_tiles = getattr(self.vision_config, "max_patches", 1)
+        image_processor = processor.image_processor
+        tokenizer = self._processor.tokenizer

-
-
-
+        # Calculate tile size and find supported resolutions
+        tile_size = self.vision_config.image_size
+        max_num_tiles = getattr(self.vision_config, "max_patches", 1)
+
+        possible_resolutions = find_supported_resolutions(
+            max_num_chunks=max_num_tiles,
+            patch_size=SizeDict(height=tile_size, width=tile_size),
+        )
+
+        # Find best fit for each image
+        best_fit_sizes = [
+            get_best_fit(
+                (image.size[1], image.size[0]),  # (height, width)
+                torch.tensor(possible_resolutions),
+                resize_to_max_canvas=image_processor.resize_to_max_canvas,
            )
+            for image in processed_data.images
+        ]
+
+        # Calculate aspect ratios and patches per image
+        aspect_ratios = [
+            (image_size[0] // tile_size, image_size[1] // tile_size)
+            for image_size in best_fit_sizes
+        ]
+
+        patches_per_image = [
+            1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
+        ]
+
+        # Add to image_inputs
+        processor_output["aspect_ratios"] = aspect_ratios
+        processor_output["patches_per_image"] = torch.tensor(patches_per_image)
+
+        # Process embed_is_patch
+        vocab = tokenizer.get_vocab()
+        patch_id = vocab.get(processor.img_patch_token, -1)
+        image_end_id = vocab.get(processor.end_of_img_token, -1)
+
+        if patch_id != -1 and image_end_id != -1:
+            input_ids = processor_output["input_ids"].view(-1)
+
+            # Remove BOS token if present
+            if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
+                input_ids = input_ids[1:]
+
+            # Find image end indices and split input_ids
+            image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
+
+            if image_end_indices.size(0) > 0:
+                # Split at image boundaries
+                split_indices = (image_end_indices + 1)[:-1]
+                split_input_ids = torch.tensor_split(input_ids, split_indices)
+                split_input_ids = [x for x in split_input_ids if x.numel() > 0]
+
+                # Create embed_is_patch for each image
+                embed_is_patch = []
+                for per_image_input_ids in split_input_ids:
+                    embed_is_patch.append(per_image_input_ids == patch_id)

-
-            best_fit_sizes = [
-                get_best_fit(
-                    (image.size[1], image.size[0]),  # (height, width)
-                    torch.tensor(possible_resolutions),
-                    resize_to_max_canvas=image_processor.resize_to_max_canvas,
-                )
-                for image in processed_data.images
-            ]
-
-            # Calculate aspect ratios and patches per image
-            aspect_ratios = [
-                (image_size[0] // tile_size, image_size[1] // tile_size)
-                for image_size in best_fit_sizes
-            ]
-
-            patches_per_image = [
-                1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
-            ]
-
-            # Add to image_inputs
-            processor_output["aspect_ratios"] = aspect_ratios
-            processor_output["patches_per_image"] = torch.tensor(patches_per_image)
-
-            # Process embed_is_patch
-            vocab = tokenizer.get_vocab()
-            patch_id = vocab.get(processor.img_patch_token, -1)
-            image_end_id = vocab.get(processor.end_of_img_token, -1)
-
-            if patch_id != -1 and image_end_id != -1:
-                input_ids = processor_output["input_ids"].view(-1)
-
-                # Remove BOS token if present
-                if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
-                    input_ids = input_ids[1:]
-
-                # Find image end indices and split input_ids
-                image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
-
-                if image_end_indices.size(0) > 0:
-                    # Split at image boundaries
-                    split_indices = (image_end_indices + 1)[:-1]
-                    split_input_ids = torch.tensor_split(input_ids, split_indices)
-                    split_input_ids = [x for x in split_input_ids if x.numel() > 0]
-
-                    # Create embed_is_patch for each image
-                    embed_is_patch = []
-                    for per_image_input_ids in split_input_ids:
-                        embed_is_patch.append(per_image_input_ids == patch_id)
-
-                    processor_output["embed_is_patch"] = embed_is_patch
+                processor_output["embed_is_patch"] = embed_is_patch

         # Convert to the format expected by SGLang
         processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]