sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions exactly as they appear in their public registries.
Files changed (99)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +36 -2
  3. sglang/srt/conversation.py +56 -3
  4. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  5. sglang/srt/disaggregation/ascend/conn.py +44 -0
  6. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +50 -18
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  9. sglang/srt/disaggregation/utils.py +25 -3
  10. sglang/srt/entrypoints/engine.py +1 -1
  11. sglang/srt/entrypoints/http_server.py +1 -0
  12. sglang/srt/entrypoints/http_server_engine.py +1 -1
  13. sglang/srt/entrypoints/openai/protocol.py +11 -0
  14. sglang/srt/entrypoints/openai/serving_chat.py +7 -0
  15. sglang/srt/function_call/function_call_parser.py +2 -0
  16. sglang/srt/function_call/kimik2_detector.py +220 -0
  17. sglang/srt/hf_transformers_utils.py +18 -0
  18. sglang/srt/jinja_template_utils.py +8 -0
  19. sglang/srt/layers/communicator.py +20 -5
  20. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  21. sglang/srt/layers/layernorm.py +2 -2
  22. sglang/srt/layers/linear.py +12 -2
  23. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  24. sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
  25. sglang/srt/layers/moe/ep_moe/layer.py +141 -2
  26. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  29. sglang/srt/layers/moe/topk.py +8 -2
  30. sglang/srt/layers/parameter.py +19 -3
  31. sglang/srt/layers/quantization/__init__.py +2 -0
  32. sglang/srt/layers/quantization/fp8.py +28 -7
  33. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  35. sglang/srt/layers/quantization/moe_wna16.py +1 -2
  36. sglang/srt/layers/quantization/w4afp8.py +264 -0
  37. sglang/srt/layers/quantization/w8a8_int8.py +738 -14
  38. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  39. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  40. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  41. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  42. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  43. sglang/srt/managers/cache_controller.py +41 -195
  44. sglang/srt/managers/io_struct.py +35 -3
  45. sglang/srt/managers/mm_utils.py +59 -96
  46. sglang/srt/managers/schedule_batch.py +17 -6
  47. sglang/srt/managers/scheduler.py +38 -6
  48. sglang/srt/managers/tokenizer_manager.py +16 -0
  49. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  50. sglang/srt/mem_cache/memory_pool.py +176 -101
  51. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  52. sglang/srt/mem_cache/radix_cache.py +8 -4
  53. sglang/srt/model_executor/forward_batch_info.py +13 -1
  54. sglang/srt/model_loader/loader.py +23 -12
  55. sglang/srt/models/deepseek_janus_pro.py +1 -1
  56. sglang/srt/models/deepseek_v2.py +78 -19
  57. sglang/srt/models/deepseek_vl2.py +1 -1
  58. sglang/srt/models/gemma3_mm.py +1 -1
  59. sglang/srt/models/gemma3n_mm.py +6 -3
  60. sglang/srt/models/internvl.py +8 -2
  61. sglang/srt/models/kimi_vl.py +8 -2
  62. sglang/srt/models/llama.py +2 -0
  63. sglang/srt/models/llava.py +3 -1
  64. sglang/srt/models/llavavid.py +1 -1
  65. sglang/srt/models/minicpmo.py +1 -2
  66. sglang/srt/models/minicpmv.py +1 -1
  67. sglang/srt/models/mixtral_quant.py +4 -0
  68. sglang/srt/models/mllama4.py +372 -82
  69. sglang/srt/models/phi4mm.py +8 -2
  70. sglang/srt/models/phimoe.py +553 -0
  71. sglang/srt/models/qwen2.py +2 -0
  72. sglang/srt/models/qwen2_5_vl.py +10 -7
  73. sglang/srt/models/qwen2_vl.py +12 -1
  74. sglang/srt/models/vila.py +8 -2
  75. sglang/srt/multimodal/mm_utils.py +2 -2
  76. sglang/srt/multimodal/processors/base_processor.py +197 -137
  77. sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
  78. sglang/srt/multimodal/processors/gemma3.py +4 -2
  79. sglang/srt/multimodal/processors/gemma3n.py +1 -1
  80. sglang/srt/multimodal/processors/internvl.py +1 -1
  81. sglang/srt/multimodal/processors/janus_pro.py +1 -1
  82. sglang/srt/multimodal/processors/kimi_vl.py +1 -1
  83. sglang/srt/multimodal/processors/minicpm.py +4 -3
  84. sglang/srt/multimodal/processors/mllama4.py +63 -61
  85. sglang/srt/multimodal/processors/phi4mm.py +1 -1
  86. sglang/srt/multimodal/processors/pixtral.py +1 -1
  87. sglang/srt/multimodal/processors/qwen_vl.py +203 -80
  88. sglang/srt/multimodal/processors/vila.py +1 -1
  89. sglang/srt/server_args.py +26 -4
  90. sglang/srt/two_batch_overlap.py +3 -0
  91. sglang/srt/utils.py +191 -48
  92. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  93. sglang/utils.py +5 -5
  94. sglang/version.py +1 -1
  95. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
  96. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
  97. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/mllama4.py
@@ -1,10 +1,16 @@
+import json as json_lib
+import logging
+import os
 from collections.abc import Iterable
 from typing import List, Optional, Set, Tuple

 import torch
 from torch import nn
-from transformers import Llama4Config, Llama4VisionModel
-from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
+from transformers import Llama4Config
+from transformers.models.llama4.modeling_llama4 import (
+    Llama4MultiModalProjector,
+    Llama4VisionModel,
+)

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -13,12 +19,23 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cpu

 _is_cpu = is_cpu()
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)


 class Llama4ForConditionalGeneration(nn.Module):
@@ -37,19 +54,85 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.config = config
         self.quant_config = quant_config

-        self.vision_model = Llama4VisionModel(config.vision_config)
-        self.multi_modal_projector = Llama4MultiModalProjector(config)
+        # Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
+        self.has_vision = self._has_vision_weights(config)
+        if not self.has_vision:
+            logger.warning(
+                "No vision weights found in checkpoint. Model will run in text-only mode. "
+                "Multimodal capabilities (image processing) will be unavailable."
+            )
+
+        if self.has_vision:
+            self.vision_model = Llama4VisionModel(config.vision_config)
+            self.multi_modal_projector = Llama4MultiModalProjector(config)
+        else:
+            self.vision_model = None
+            self.multi_modal_projector = None

         # Initialize the language model
         from sglang.srt.models.llama4 import Llama4ForCausalLM

         self.language_model = Llama4ForCausalLM(
-            config.text_config,
+            config.text_config if hasattr(config, "text_config") else config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )

-        self.logits_processor = LogitsProcessor(config.text_config)
+        self.logits_processor = LogitsProcessor(
+            config.text_config if hasattr(config, "text_config") else config
+        )
+
+    def _has_vision_weights(self, config) -> bool:
+        """Check if the model has vision components by examining the checkpoint."""
+        model_path = getattr(config, "_name_or_path", None)
+        if not model_path:
+            return False
+
+        # Check if this is a local path first
+        if os.path.isdir(model_path):
+            index_file = os.path.join(model_path, "model.safetensors.index.json")
+            if os.path.exists(index_file):
+                return self._check_vision_weights_in_index(index_file)
+
+        # For HuggingFace models, we need to check the actual checkpoint.
+        # The config might say it's multimodal, but the checkpoint might be text-only.
+        try:
+            # Try to access the HuggingFace cache directory
+            from huggingface_hub import try_to_load_from_cache
+
+            # Check if the index file exists in the cache
+            index_file_path = try_to_load_from_cache(
+                repo_id=model_path,
+                filename="model.safetensors.index.json",
+                cache_dir=None,
+            )
+
+            if index_file_path and os.path.exists(index_file_path):
+                return self._check_vision_weights_in_index(index_file_path)
+
+        except Exception:
+            # If we can't access the cache, fall back to config-based detection
+            pass
+
+        # Fallback: assume text-only
+        return False
+
+    def _check_vision_weights_in_index(self, index_file: str) -> bool:
+        """Check if the model.safetensors.index.json contains vision weights."""
+        try:
+            with open(index_file, "r") as f:
+                index_data = json_lib.load(f)
+
+            vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"]
+            weight_names = index_data.get("weight_map", {}).keys()
+
+            return any(
+                pattern in weight_name
+                for weight_name in weight_names
+                for pattern in vision_patterns
+            )
+        except (OSError, json_lib.JSONDecodeError, KeyError):
+            return False

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
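
Note: the text-only detection above trusts the checkpoint index over the config. A minimal standalone sketch of the same scan, run against a toy in-memory index payload (the weight names below are illustrative, not taken from a real checkpoint):

```python
# Toy stand-in for a parsed model.safetensors.index.json: "weight_map"
# maps each parameter name to the shard file that stores it.
TEXT_ONLY_INDEX = {
    "weight_map": {
        "language_model.model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "language_model.lm_head.weight": "model-00002-of-00002.safetensors",
    }
}

VISION_PATTERNS = ("vision_model", "vision_tower", "multi_modal_projector")


def index_has_vision_weights(index_data: dict) -> bool:
    """Same scan as _check_vision_weights_in_index: multimodal iff any
    weight name contains a vision-related substring."""
    return any(
        pattern in weight_name
        for weight_name in index_data.get("weight_map", {})
        for pattern in VISION_PATTERNS
    )


print(index_has_vision_weights(TEXT_ONLY_INDEX))  # False -> text-only mode
```

With a modelopt fp8 Llama4 export that ships no vision tower, the scan returns False and the constructor leaves vision_model and multi_modal_projector as None.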
@@ -59,6 +142,10 @@ class Llama4ForConditionalGeneration(nn.Module):
         self,
         items: List[MultimodalDataItem],
     ) -> torch.Tensor:
+        # For text-only models, return None or raise an error
+        if not self.has_vision or self.vision_model is None:
+            raise ValueError("Vision model not available for text-only checkpoint")
+
         pixel_values = (
             torch.concat([item.pixel_values for item in items])
             .to(next(self.vision_model.parameters()).device)
@@ -79,11 +166,16 @@
         **kwargs: object,
     ) -> torch.Tensor:

+        # For text-only models, pass None for image_data_embedding_func
+        image_embedding_func = self.get_image_feature if self.has_vision else None
+
         hs = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             positions=positions,
         )

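Note: this hunk swaps the single-purpose image_data_embedding_func argument for a data_embedding_funcs dict keyed by Modality, so one routine can dispatch any number of modalities. A simplified sketch of that dispatch shape (the Modality enum here is a stand-in; the real general_mm_embed_routine takes more parameters):

```python
from enum import Enum, auto
from typing import Callable, Dict, List


class Modality(Enum):
    # Stand-in for sglang.srt.managers.schedule_batch.Modality
    IMAGE = auto()
    AUDIO = auto()


def embed_multimodal(
    items_by_modality: Dict[Modality, List[object]],
    data_embedding_funcs: Dict[Modality, Callable[[List[object]], object]],
) -> List[object]:
    """Dispatch each group of items to the embedder registered for its
    modality; modalities with no registered function are skipped."""
    embeddings = []
    for modality, items in items_by_modality.items():
        func = data_embedding_funcs.get(modality)
        if func is not None:
            embeddings.append(func(items))
    return embeddings
```

The dict form also gives a text-only model a natural way to simply omit the IMAGE entry.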
@@ -124,7 +216,6 @@
         return name, loaded_weight

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
-
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
@@ -137,11 +228,12 @@
         ]

         params_dict = dict(self.named_parameters())
+        num_experts = (
+            self.config.text_config.num_local_experts
+            if hasattr(self.config, "text_config")
+            else self.config.num_local_experts
+        )

-        num_experts = self.config.text_config.num_local_experts
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
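
Note: the `hasattr(config, "text_config")` ternary here mirrors the constructor change above; it lets the loader work whether `config` is the composite multimodal config or already a bare text config. In isolation (assuming a `config` object as in the diff):

```python
# Resolve the effective text config for both checkpoint layouts.
text_config = config.text_config if hasattr(config, "text_config") else config
num_experts = text_config.num_local_experts
```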
@@ -150,81 +242,279 @@
         )

         for name, loaded_weight in weights:
-            if not "vision" in name:
+            if self._should_skip_weight(name):
+                continue
+
+            name = self._transform_weight_name(name)
+
+            if "vision" not in name:
                 name, loaded_weight = self.permute_qk_weight_for_rotary(
                     name, loaded_weight
                 )

-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-
-                if "vision" in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
+            if self._handle_scale_remapping(name, params_dict):
+                continue
+
+            if self._handle_stacked_params(
+                name, loaded_weight, stacked_params_mapping, params_dict
+            ):
+                continue
+
+            if self._handle_expert_weights(
+                name, loaded_weight, expert_params_mapping, params_dict, num_experts
+            ):
+                continue
+
+            self._handle_default_weight(name, loaded_weight, params_dict)
+
+    def _should_skip_weight(self, name: str) -> bool:
+        """Check if we should skip loading this weight."""
+        return "vision" in name and not self.has_vision
+
+    def _transform_weight_name(self, name: str) -> str:
+        """Transform weight name by adding language_model prefix if needed."""
+        if (
+            not name.startswith("language_model.")
+            and "vision" not in name
+            and "multi_modal_projector" not in name
+        ):
+            return f"language_model.{name}"
+        return name
+
+    def _handle_scale_remapping(self, name: str, params_dict: dict) -> bool:
+        """Handle scale parameter remapping. Returns True if handled."""
+        if "scale" in name and "expert" not in name:
+            remapped_name = maybe_remap_kv_scale_name(name, params_dict)
+            return remapped_name is None
+        return False
+
+    def _handle_stacked_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        stacked_params_mapping: list,
+        params_dict: dict,
+    ) -> bool:
+        """Handle stacked parameter loading. Returns True if handled."""
+        for param_name, weight_name, shard_id in stacked_params_mapping:
+            if weight_name in name and "vision" not in name:
+                transformed_name = name.replace(weight_name, param_name)
+                param = params_dict[transformed_name]
+                param.weight_loader(param, loaded_weight, shard_id)
+                return True
+        return False
+
+    def _handle_expert_weights(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        expert_params_mapping: list,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle expert weight loading for MoE (Mixture of Experts) layers.
+
+        Args:
+            name: Parameter name from the checkpoint
+            loaded_weight: The weight tensor to be loaded
+            expert_params_mapping: Mapping of parameter names to expert configurations
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts in the MoE layer
+
+        Returns:
+            bool: True if the parameter was handled (is an expert parameter), False otherwise
+        """
+        if ".experts" not in name:
+            return False
+
+        if "experts.gate_up_proj" not in name and "experts.down_proj" not in name:
+            return self._handle_other_expert_params(
+                name, loaded_weight, expert_params_mapping, params_dict
+            )
+
+        if "scale" in name:
+            return self._handle_expert_scale_params(
+                name, loaded_weight, params_dict, num_experts
+            )
+        else:
+            return self._handle_expert_weight_params(
+                name, loaded_weight, params_dict, num_experts
+            )
+
+    def _handle_other_expert_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        expert_params_mapping: list,
+        params_dict: dict,
+    ) -> bool:
+        """Handle expert parameters that are not gate_up_proj or down_proj weights.
+
+        Args:
+            name: Parameter name from the checkpoint
+            loaded_weight: The weight tensor to be loaded
+            expert_params_mapping: List of tuples mapping checkpoint names to model parameters
+            params_dict: Dictionary of model parameters
+
+        Returns:
+            bool: True if parameter was found and handled, False otherwise
+        """
+        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
+            if weight_name in name:
+                transformed_name = name.replace(weight_name, param_name)
+                param = params_dict[transformed_name]
+                param.weight_loader(
+                    param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
+                )
+                return True
+        return False
+
+    def _transform_expert_name(
+        self, name: str, is_weight: bool = False
+    ) -> Tuple[str, str, List[str]]:
+        """Transform expert parameter name and get shard information.
+
+        Args:
+            name: The original parameter name
+            is_weight: Whether this is a weight parameter (adds _weight suffix)
+
+        Returns:
+            Tuple of (transformed_name, shard_id, shard_id_list)
+        """
+        suffix = "_weight" if is_weight else ""
+
+        if ".gate_up_proj" in name:
+            transformed_name = name.replace(
+                ".experts.gate_up_proj", f".experts.w13{suffix}"
+            )
+            shard_id = "w13"
+            shard_id_list = ["w1", "w3"]
+        else:  # down_proj
+            transformed_name = name.replace(
+                ".experts.down_proj", f".experts.w2{suffix}"
+            )
+            shard_id = "w2"
+            shard_id_list = ["w2"]
+
+        return transformed_name, shard_id, shard_id_list
+
+    def _handle_expert_scale_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle quantization scale parameters for expert weights.
+
+        Args:
+            name: Parameter name containing scale information
+            loaded_weight: Scale tensor to be loaded
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts for broadcast operations
+
+        Returns:
+            bool: True (always handles scale parameters)
+        """
+        import re
+
+        # Check if this matches the expert parameter pattern: experts.{expert_id}.{param_name}
+        expert_match = re.search(r"experts\.(\d+)\.", name)
+
+        # Transform name
+        transformed_name, _, _ = self._transform_expert_name(name)
+
+        if transformed_name not in params_dict:
+            return True
+
+        param = params_dict[transformed_name]
+
+        # Handle scale parameters
+        if expert_match:
+            # If we have a specific expert ID, only load for that expert
+            expert_id = int(expert_match.group(1))
+            # For scale parameters, we can directly set the value
+            param.data[expert_id] = loaded_weight
+        else:
+            # No expert ID found - this is a single scale for all experts
+            # Load the same scale for all experts
+            for expert_id in range(num_experts):
+                param.data[expert_id] = loaded_weight
+
+        return True
+
+    def _handle_expert_weight_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle actual weight tensors for expert layers (gate_up_proj and down_proj).
+
+        Args:
+            name: Parameter name (should contain gate_up_proj or down_proj)
+            loaded_weight: Weight tensor(s) to be loaded
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts for tensor distribution
+
+        Returns:
+            bool: True (always handles weight parameters)
+        """
+        # Transform name and get shard info
+        transformed_name, _, shard_id_list = self._transform_expert_name(
+            name, is_weight=True
+        )
+
+        if ".gate_up_proj" in name:
+            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
+        else:  # down_proj
+            loaded_weight_list = [loaded_weight]
+
+        for param_name, weight_chunk, shard_id in zip(
+            [transformed_name] * len(shard_id_list), loaded_weight_list, shard_id_list
+        ):
+            if param_name not in params_dict:
+                continue
+
+            param = params_dict[param_name]
+            weight_loader = param.weight_loader
+
+            # Handle the case where loaded_weight might be a single tensor for all experts
+            if weight_chunk.dim() == 2:
+                # Single tensor case - load for all experts
+                for expert_id in range(num_experts):
+                    weight_loader(
+                        param,
+                        weight_chunk.T,
+                        param_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
             else:
-                if ".experts" in name:
-                    # NOTE: llama4 fp8 has different weight format for experts
-                    if (
-                        "experts.gate_up_proj" not in name
-                        and "experts.down_proj" not in name
-                    ):
-                        for mapping in expert_params_mapping:
-                            param_name, weight_name, expert_id, shard_id = mapping
-                            if weight_name not in name:
-                                continue
-                            name = name.replace(weight_name, param_name)
-                            param = params_dict[name]
-                            weight_loader = param.weight_loader
-                            weight_loader(
-                                param,
-                                loaded_weight,
-                                name,
-                                shard_id=shard_id,
-                                expert_id=expert_id,
-                            )
-                            break
-                    else:
-                        if ".gate_up_proj" in name:
-                            name_list = [
-                                name.replace(
-                                    ".experts.gate_up_proj", ".experts.w13_weight"
-                                )
-                            ] * 2
-                            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
-                            shard_id_list = ["w1", "w3"]
-                        else:
-                            name_list = [
-                                name.replace(".experts.down_proj", ".experts.w2_weight")
-                            ]
-                            shard_id_list = ["w2"]
-                            loaded_weight_list = [loaded_weight]
-                        for name, loaded_weight, shard_id in zip(
-                            name_list, loaded_weight_list, shard_id_list
-                        ):
-                            param = params_dict[name]
-                            weight_loader = param.weight_loader
-                            for expert_id in range(num_experts):
-                                weight_loader(
-                                    param,
-                                    loaded_weight[expert_id].T,
-                                    name,
-                                    shard_id=shard_id,
-                                    expert_id=expert_id,
-                                )
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
+                # Multiple experts case - load each expert's weights
+                for expert_id in range(num_experts):
+                    weight_loader(
+                        param,
+                        weight_chunk[expert_id].T,
+                        param_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
                     )
-                    weight_loader(param, loaded_weight)
+
+        return True
+
+    def _handle_default_weight(
+        self, name: str, loaded_weight: torch.Tensor, params_dict: dict
+    ):
+        """Handle default weight loading."""
+        # Skip loading extra bias for GPTQ models
+        if name.endswith(".bias") and name not in params_dict:
+            return
+
+        param = params_dict[name]
+        weight_loader = getattr(param, "weight_loader", default_weight_loader)
+        weight_loader(param, loaded_weight)

     def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
         if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
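
Note: the core of _handle_expert_weight_params is splitting a fused gate_up_proj checkpoint tensor into w1/w3 shards and transposing each expert's slice before handing it to the weight loader. A self-contained sketch of that reshaping on toy tensors (all dimensions illustrative):

```python
import torch

num_experts, hidden, intermediate = 4, 8, 16

# Checkpoint-style fused tensor: per expert, gate (w1) and up (w3)
# projections are stored back to back along the last dimension.
gate_up = torch.randn(num_experts, hidden, 2 * intermediate)

# Mirrors _transform_expert_name(name, is_weight=True) for gate_up_proj.
name = "model.layers.0.feed_forward.experts.gate_up_proj"
target = name.replace(".experts.gate_up_proj", ".experts.w13_weight")

# Mirrors loaded_weight.chunk(2, dim=-1): halve the last dim into w1 / w3.
w1_chunk, w3_chunk = gate_up.chunk(2, dim=-1)

for shard_id, chunk in (("w1", w1_chunk), ("w3", w3_chunk)):
    for expert_id in range(num_experts):
        # Same per-expert transpose as weight_chunk[expert_id].T in the diff.
        shard = chunk[expert_id].T
        assert shard.shape == (intermediate, hidden)

print(target)  # model.layers.0.feed_forward.experts.w13_weight
```

The weight_chunk.dim() == 2 branch in the diff covers the case where the checkpoint stores one shared 2-D matrix instead of a per-expert stack; the same transposed tensor is then loaded for every expert id.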
sglang/srt/models/phi4mm.py
@@ -31,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.idefics2 import Idefics2VisionTransformer
@@ -439,7 +443,9 @@ class Phi4MMForCausalLM(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             positions=positions,
         )