sglang 0.4.9.post1__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. sglang/srt/configs/model_config.py +24 -1
  2. sglang/srt/conversation.py +21 -2
  3. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  4. sglang/srt/disaggregation/ascend/conn.py +44 -0
  5. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  6. sglang/srt/disaggregation/mooncake/conn.py +15 -14
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  8. sglang/srt/disaggregation/utils.py +25 -3
  9. sglang/srt/entrypoints/engine.py +1 -1
  10. sglang/srt/entrypoints/http_server.py +1 -0
  11. sglang/srt/entrypoints/openai/protocol.py +11 -0
  12. sglang/srt/entrypoints/openai/serving_chat.py +7 -0
  13. sglang/srt/function_call/function_call_parser.py +2 -0
  14. sglang/srt/function_call/kimik2_detector.py +220 -0
  15. sglang/srt/hf_transformers_utils.py +18 -0
  16. sglang/srt/jinja_template_utils.py +8 -0
  17. sglang/srt/layers/communicator.py +17 -4
  18. sglang/srt/layers/linear.py +12 -2
  19. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  20. sglang/srt/layers/moe/ep_moe/layer.py +2 -1
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -2
  22. sglang/srt/layers/moe/topk.py +8 -2
  23. sglang/srt/layers/parameter.py +19 -3
  24. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  25. sglang/srt/layers/quantization/moe_wna16.py +1 -2
  26. sglang/srt/layers/quantization/w8a8_int8.py +738 -14
  27. sglang/srt/managers/io_struct.py +27 -2
  28. sglang/srt/managers/mm_utils.py +55 -94
  29. sglang/srt/managers/schedule_batch.py +16 -5
  30. sglang/srt/managers/scheduler.py +21 -1
  31. sglang/srt/managers/tokenizer_manager.py +16 -0
  32. sglang/srt/mem_cache/memory_pool.py +65 -40
  33. sglang/srt/model_executor/forward_batch_info.py +13 -1
  34. sglang/srt/model_loader/loader.py +23 -12
  35. sglang/srt/models/deepseek_janus_pro.py +1 -1
  36. sglang/srt/models/deepseek_v2.py +62 -17
  37. sglang/srt/models/deepseek_vl2.py +1 -1
  38. sglang/srt/models/gemma3_mm.py +1 -1
  39. sglang/srt/models/gemma3n_mm.py +6 -3
  40. sglang/srt/models/internvl.py +8 -2
  41. sglang/srt/models/kimi_vl.py +8 -2
  42. sglang/srt/models/llama.py +2 -0
  43. sglang/srt/models/llava.py +3 -1
  44. sglang/srt/models/llavavid.py +1 -1
  45. sglang/srt/models/minicpmo.py +1 -2
  46. sglang/srt/models/minicpmv.py +1 -1
  47. sglang/srt/models/mixtral_quant.py +4 -0
  48. sglang/srt/models/mllama4.py +13 -4
  49. sglang/srt/models/phi4mm.py +8 -2
  50. sglang/srt/models/phimoe.py +553 -0
  51. sglang/srt/models/qwen2.py +2 -0
  52. sglang/srt/models/qwen2_5_vl.py +10 -7
  53. sglang/srt/models/qwen2_vl.py +12 -1
  54. sglang/srt/models/vila.py +8 -2
  55. sglang/srt/multimodal/processors/base_processor.py +197 -137
  56. sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
  57. sglang/srt/multimodal/processors/gemma3.py +4 -2
  58. sglang/srt/multimodal/processors/gemma3n.py +1 -1
  59. sglang/srt/multimodal/processors/internvl.py +1 -1
  60. sglang/srt/multimodal/processors/janus_pro.py +1 -1
  61. sglang/srt/multimodal/processors/kimi_vl.py +1 -1
  62. sglang/srt/multimodal/processors/minicpm.py +4 -3
  63. sglang/srt/multimodal/processors/mllama4.py +1 -1
  64. sglang/srt/multimodal/processors/phi4mm.py +1 -1
  65. sglang/srt/multimodal/processors/pixtral.py +1 -1
  66. sglang/srt/multimodal/processors/qwen_vl.py +203 -80
  67. sglang/srt/multimodal/processors/vila.py +1 -1
  68. sglang/srt/server_args.py +11 -4
  69. sglang/srt/utils.py +154 -31
  70. sglang/version.py +1 -1
  71. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +4 -3
  72. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +75 -70
  73. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
  75. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/loader.py
@@ -64,10 +64,13 @@ from sglang.srt.model_loader.weight_utils import (
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device_capability,
+    is_npu,
     is_pin_memory_available,
     set_weight_attrs,
 )
 
+_is_npu = is_npu()
+
 
 @contextmanager
 def device_loading_context(module: torch.nn.Module, target_device: torch.device):
@@ -127,18 +130,19 @@ def _get_quantization_config(
     # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
     if quant_config is None:
         return None
-    major, minor = get_device_capability()
-
-    if major is not None and minor is not None:
-        assert 0 <= minor < 10
-        capability = major * 10 + minor
-        if capability < quant_config.get_min_capability():
-            raise ValueError(
-                f"The quantization method {model_config.quantization} "
-                "is not supported for the current GPU. "
-                f"Minimum capability: {quant_config.get_min_capability()}. "
-                f"Current capability: {capability}."
-            )
+    if not _is_npu:
+        major, minor = get_device_capability()
+
+        if major is not None and minor is not None:
+            assert 0 <= minor < 10
+            capability = major * 10 + minor
+            if capability < quant_config.get_min_capability():
+                raise ValueError(
+                    f"The quantization method {model_config.quantization} "
+                    "is not supported for the current GPU. "
+                    f"Minimum capability: {quant_config.get_min_capability()}. "
+                    f"Current capability: {capability}."
+                )
     supported_dtypes = quant_config.get_supported_act_dtypes()
     if model_config.dtype not in supported_dtypes:
         raise ValueError(
@@ -157,6 +161,13 @@ def _initialize_model(
     """Initialize a model with the given configurations."""
     model_class, _ = get_model_architecture(model_config)
    packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
+    if _is_npu:
+        packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
+            "q_a_proj",
+            "kv_a_proj_with_mqa",
+        ]
+        packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
+        packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
     quant_config = _get_quantization_config(
         model_config, load_config, packed_modules_mapping
     )
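For readers unfamiliar with the loader, a packed-modules mapping of this shape tells the weight loader which per-projection checkpoint tensors feed a single fused parameter. The helper below is a hypothetical illustration of that lookup; only the mapping literals come from the hunk above, the function and its name are not sglang code.

from typing import Optional, Tuple

# Mapping literals mirror the NPU branch added above; the lookup helper is illustrative.
packed_modules_mapping = {
    "fused_qkv_a_proj_with_mqa": ["q_a_proj", "kv_a_proj_with_mqa"],
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}


def fused_param_for(weight_name: str) -> Optional[Tuple[str, int]]:
    """Map a checkpoint weight name to (fused parameter name, shard index)."""
    for fused, shards in packed_modules_mapping.items():
        for idx, shard in enumerate(shards):
            if f".{shard}." in weight_name:
                return weight_name.replace(shard, fused), idx
    return None


# e.g. "model.layers.0.self_attn.q_proj.weight" -> ("model.layers.0.self_attn.qkv_proj.weight", 0)
print(fused_param_for("model.layers.0.self_attn.q_proj.weight"))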
sglang/srt/models/deepseek_janus_pro.py
@@ -1989,7 +1989,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             language_model=self.language_model,
             positions=positions,
         )
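Throughout this release, callers of general_mm_embed_routine stop passing a bare image_data_embedding_func and instead hand over either the whole model (multimodal_model=self, as above) or an explicit per-modality mapping (shown in later hunks). A minimal hedged sketch of the first form follows; the wrapper class is hypothetical, and only the routine and its keyword arguments appear in this diff.

from torch import nn

from sglang.srt.managers.mm_utils import general_mm_embed_routine


class MyMultimodalLM(nn.Module):
    # Hypothetical model wrapper; assumes self.language_model and the usual
    # get_image_feature method exist on the model, as they do for the models above.
    def forward(self, input_ids, positions, forward_batch):
        return general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            multimodal_model=self,  # 0.4.9.post2 form: pass the model, not a bare function
            language_model=self.language_model,
            positions=positions,
        )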
sglang/srt/models/deepseek_v2.py
@@ -77,6 +77,7 @@ from sglang.srt.layers.quantization.int8_utils import (
 )
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -100,6 +101,7 @@ from sglang.srt.utils import (
     get_int_env_var,
     is_cpu,
     is_cuda,
+    is_flashinfer_available,
     is_hip,
     is_non_idle_and_non_empty,
     log_info_on_rank0,
@@ -132,6 +134,9 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )
 
+_is_flashinfer_available = is_flashinfer_available()
+_is_sm100_supported = is_cuda() and is_sm100_supported()
+
 
 logger = logging.getLogger(__name__)
 
@@ -195,13 +200,13 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x, forward_batch=None):
+    def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
         return x
 
 
@@ -409,7 +414,10 @@ class DeepseekV2MoE(nn.Module):
         ]
 
     def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        can_fuse_mlp_allreduce: bool = False,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -418,13 +426,17 @@ class DeepseekV2MoE(nn.Module):
                 and self.num_fused_shared_experts == 0
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
-                return self.forward_normal_dual_stream(hidden_states)
+                return self.forward_normal_dual_stream(
+                    hidden_states, can_fuse_mlp_allreduce
+                )
             else:
-                return self.forward_normal(hidden_states)
+                return self.forward_normal(hidden_states, can_fuse_mlp_allreduce)
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
-    def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward_normal_dual_stream(
+        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+    ) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
 
@@ -440,11 +452,13 @@ class DeepseekV2MoE(nn.Module):
         final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
         final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward_normal(
+        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+    ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
         ):
@@ -461,7 +475,7 @@ class DeepseekV2MoE(nn.Module):
         final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
@@ -514,7 +528,7 @@ class DeepseekV2MoE(nn.Module):
                 None,  # a2_scale
                 True,  # is_vnni
             )
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not self.can_fuse_mlp_allreduce:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
@@ -1818,6 +1832,29 @@ class DeepseekV2DecoderLayer(nn.Module):
             and layer_id % self.config.moe_layer_freq == 0
         )
 
+    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
+        """Check if MLP allreduce can be fused with next layer's add_rmsnorm"""
+
+        if (
+            self.layer_id == self.config.num_hidden_layers - 1
+            or get_tensor_model_parallel_world_size() <= 1
+        ):
+            return False
+
+        if not global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False):
+            return False
+
+        if not _is_sm100_supported or not _is_flashinfer_available:
+            return False
+
+        if hasattr(forward_batch, "input_ids") and (
+            forward_batch.input_ids.shape[0] == 0
+            or forward_batch.input_ids.shape[0] > 128
+        ):
+            return False
+
+        return True
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -1842,12 +1879,22 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch)
-
-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
+        can_fuse_mlp_allreduce = (
+            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
+            and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
+            and not self.is_nextn
         )
 
+        hidden_states = self.mlp(hidden_states, forward_batch, can_fuse_mlp_allreduce)
+
+        if can_fuse_mlp_allreduce:
+            hidden_states._sglang_needs_allreduce_fusion = True
+
+        if not can_fuse_mlp_allreduce:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
+
         return hidden_states, residual
 
     def op_comm_prepare_attn(
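Condensed, the core of the gating added in _should_fuse_mlp_allreduce_with_next_layer reduces to the standalone predicate below. This is a hedged paraphrase for readability, not the sglang implementation; all parameter names are illustrative.

def can_defer_mlp_allreduce(
    layer_id: int,
    num_layers: int,
    tp_size: int,
    fusion_flag_enabled: bool,  # global_server_args_dict["enable_flashinfer_allreduce_fusion"]
    is_sm100: bool,
    has_flashinfer: bool,
    num_tokens: int,
) -> bool:
    """Skip the per-layer all-reduce only when the next layer's add_rmsnorm can absorb it."""
    if layer_id == num_layers - 1 or tp_size <= 1:
        return False  # last layer, or no tensor parallelism: nothing to fuse
    if not fusion_flag_enabled:
        return False
    if not (is_sm100 and has_flashinfer):
        return False  # the fused path needs an SM100-class GPU and flashinfer
    return 0 < num_tokens <= 128  # small batches only, per the input_ids check above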
@@ -2146,7 +2193,6 @@ class DeepseekV2ForCausalLM(nn.Module):
         # This may affect the accuracy of fp8 model.
         # Fix deepseek v3 blockwise bmm by using deep_gemm
         use_deep_gemm_bmm = False
-        model_dtype = torch.get_default_dtype()
 
         if w.dtype in (
             torch.float8_e4m3fn,
@@ -2172,7 +2218,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                _is_cuda
                and weight_block_size[0] == 128
                and weight_block_size[1] == 128
-               and model_dtype == torch.bfloat16
            ):
                if (
                    deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
@@ -2186,7 +2231,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                        weight,
                        weight_scale,
                        weight_block_size,
-                        model_dtype,
+                        torch.bfloat16,
                    )
                else:
                    w, scale = block_quant_to_tensor_quant(
sglang/srt/models/deepseek_vl2.py
@@ -227,7 +227,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
             input_ids=input_ids,
             positions=positions,
             forward_batch=forward_batch,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             language_model=self.language_model,
         )
 
sglang/srt/models/gemma3_mm.py
@@ -374,7 +374,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
             input_ids=llm_input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             positions=positions,
         )
 
sglang/srt/models/gemma3n_mm.py
@@ -1,7 +1,7 @@
 import logging
 import re
 from functools import lru_cache
-from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict, Union
+from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union
 
 import torch
 from torch import nn
@@ -25,6 +25,7 @@ from sglang.srt.managers.mm_utils import (
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import (
+    Modality,
     MultimodalDataItem,
     MultimodalInputs,
     flatten_nested_list,
@@ -434,8 +435,10 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
-            audio_data_embedding_func=self.get_audio_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+                Modality.AUDIO: self.get_audio_feature,
+            },
             positions=positions,
             per_layer_inputs=per_layer_inputs,
         )
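The per-modality form used here replaces the separate image_data_embedding_func / audio_data_embedding_func keywords with one dict keyed by Modality. A minimal hedged sketch of the same calling pattern follows; the wrapper class and its attributes are hypothetical, while the routine, the Modality enum, and the data_embedding_funcs keyword come from this diff.

from torch import nn

from sglang.srt.managers.mm_utils import general_mm_embed_routine
from sglang.srt.managers.schedule_batch import Modality


class MyAudioVisualLM(nn.Module):
    # Hypothetical wrapper; assumes get_image_feature/get_audio_feature and
    # self.language_model are defined on the model, as in the hunks above.
    def forward(self, input_ids, positions, forward_batch):
        return general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
            data_embedding_funcs={
                Modality.IMAGE: self.get_image_feature,
                Modality.AUDIO: self.get_audio_feature,
            },
            positions=positions,
        )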
sglang/srt/models/internvl.py
@@ -29,7 +29,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_janus_pro import DropPath
@@ -523,7 +527,9 @@ class InternVLChatModel(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             positions=positions,
         )
 
sglang/srt/models/kimi_vl.py
@@ -67,7 +67,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -168,7 +172,9 @@ class KimiVLForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             positions=positions,
         )
 
sglang/srt/models/llama.py
@@ -575,6 +575,8 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
sglang/srt/models/llava.py
@@ -787,7 +787,9 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
             forward_batch=forward_batch,
             get_embedding=get_embedding,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             placeholder_tokens=None,  # using mm_item.pad_value
             positions=positions,
         )
sglang/srt/models/llavavid.py
@@ -142,7 +142,7 @@ class LlavaVidForCausalLM(nn.Module):
         )
         image_offsets = [
             flatten_nested_list(
-                [item.image_offsets for item in image_inputs[i].mm_items]
+                [item.offsets for item in image_inputs[i].mm_items]
             )
             for i in range(bs)
             if need_vision[i]
sglang/srt/models/minicpmo.py
@@ -1827,8 +1827,7 @@ class MiniCPMO(MiniCPMBaseModel):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.llm,
-            image_data_embedding_func=self.get_image_feature,
-            audio_data_embedding_func=self.get_audio_feature,
+            multimodal_model=self,
             positions=positions,
         )
         return hidden_states
sglang/srt/models/minicpmv.py
@@ -573,7 +573,7 @@ class MiniCPMBaseModel(nn.Module):
         hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             language_model=self.llm,
             positions=positions,
         )
sglang/srt/models/mixtral_quant.py
@@ -407,6 +407,8 @@ class QuantMixtralForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -418,6 +420,8 @@ class QuantMixtralForCausalLM(nn.Module):
                 # Skip experts that are not assigned to this worker.
                 if "block_sparse_moe.experts." in name and name not in params_dict:
                     continue
+                if name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
sglang/srt/models/mllama4.py
@@ -6,8 +6,11 @@ from typing import List, Optional, Set, Tuple
 
 import torch
 from torch import nn
-from transformers import Llama4Config, Llama4VisionModel
-from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
+from transformers import Llama4Config
+from transformers.models.llama4.modeling_llama4 import (
+    Llama4MultiModalProjector,
+    Llama4VisionModel,
+)
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -16,7 +19,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cpu
@@ -166,7 +173,9 @@ class Llama4ForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=image_embedding_func,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             positions=positions,
         )
 
sglang/srt/models/phi4mm.py
@@ -31,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.idefics2 import Idefics2VisionTransformer
@@ -439,7 +443,9 @@ class Phi4MMForCausalLM(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             positions=positions,
         )