sglang-0.4.9-py3-none-any.whl → sglang-0.4.9.post2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +36 -2
  3. sglang/srt/conversation.py +56 -3
  4. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  5. sglang/srt/disaggregation/ascend/conn.py +44 -0
  6. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +50 -18
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  9. sglang/srt/disaggregation/utils.py +25 -3
  10. sglang/srt/entrypoints/engine.py +1 -1
  11. sglang/srt/entrypoints/http_server.py +1 -0
  12. sglang/srt/entrypoints/http_server_engine.py +1 -1
  13. sglang/srt/entrypoints/openai/protocol.py +11 -0
  14. sglang/srt/entrypoints/openai/serving_chat.py +7 -0
  15. sglang/srt/function_call/function_call_parser.py +2 -0
  16. sglang/srt/function_call/kimik2_detector.py +220 -0
  17. sglang/srt/hf_transformers_utils.py +18 -0
  18. sglang/srt/jinja_template_utils.py +8 -0
  19. sglang/srt/layers/communicator.py +20 -5
  20. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  21. sglang/srt/layers/layernorm.py +2 -2
  22. sglang/srt/layers/linear.py +12 -2
  23. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  24. sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
  25. sglang/srt/layers/moe/ep_moe/layer.py +141 -2
  26. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  29. sglang/srt/layers/moe/topk.py +8 -2
  30. sglang/srt/layers/parameter.py +19 -3
  31. sglang/srt/layers/quantization/__init__.py +2 -0
  32. sglang/srt/layers/quantization/fp8.py +28 -7
  33. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  35. sglang/srt/layers/quantization/moe_wna16.py +1 -2
  36. sglang/srt/layers/quantization/w4afp8.py +264 -0
  37. sglang/srt/layers/quantization/w8a8_int8.py +738 -14
  38. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  39. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  40. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  41. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  42. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  43. sglang/srt/managers/cache_controller.py +41 -195
  44. sglang/srt/managers/io_struct.py +35 -3
  45. sglang/srt/managers/mm_utils.py +59 -96
  46. sglang/srt/managers/schedule_batch.py +17 -6
  47. sglang/srt/managers/scheduler.py +38 -6
  48. sglang/srt/managers/tokenizer_manager.py +16 -0
  49. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  50. sglang/srt/mem_cache/memory_pool.py +176 -101
  51. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  52. sglang/srt/mem_cache/radix_cache.py +8 -4
  53. sglang/srt/model_executor/forward_batch_info.py +13 -1
  54. sglang/srt/model_loader/loader.py +23 -12
  55. sglang/srt/models/deepseek_janus_pro.py +1 -1
  56. sglang/srt/models/deepseek_v2.py +78 -19
  57. sglang/srt/models/deepseek_vl2.py +1 -1
  58. sglang/srt/models/gemma3_mm.py +1 -1
  59. sglang/srt/models/gemma3n_mm.py +6 -3
  60. sglang/srt/models/internvl.py +8 -2
  61. sglang/srt/models/kimi_vl.py +8 -2
  62. sglang/srt/models/llama.py +2 -0
  63. sglang/srt/models/llava.py +3 -1
  64. sglang/srt/models/llavavid.py +1 -1
  65. sglang/srt/models/minicpmo.py +1 -2
  66. sglang/srt/models/minicpmv.py +1 -1
  67. sglang/srt/models/mixtral_quant.py +4 -0
  68. sglang/srt/models/mllama4.py +372 -82
  69. sglang/srt/models/phi4mm.py +8 -2
  70. sglang/srt/models/phimoe.py +553 -0
  71. sglang/srt/models/qwen2.py +2 -0
  72. sglang/srt/models/qwen2_5_vl.py +10 -7
  73. sglang/srt/models/qwen2_vl.py +12 -1
  74. sglang/srt/models/vila.py +8 -2
  75. sglang/srt/multimodal/mm_utils.py +2 -2
  76. sglang/srt/multimodal/processors/base_processor.py +197 -137
  77. sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
  78. sglang/srt/multimodal/processors/gemma3.py +4 -2
  79. sglang/srt/multimodal/processors/gemma3n.py +1 -1
  80. sglang/srt/multimodal/processors/internvl.py +1 -1
  81. sglang/srt/multimodal/processors/janus_pro.py +1 -1
  82. sglang/srt/multimodal/processors/kimi_vl.py +1 -1
  83. sglang/srt/multimodal/processors/minicpm.py +4 -3
  84. sglang/srt/multimodal/processors/mllama4.py +63 -61
  85. sglang/srt/multimodal/processors/phi4mm.py +1 -1
  86. sglang/srt/multimodal/processors/pixtral.py +1 -1
  87. sglang/srt/multimodal/processors/qwen_vl.py +203 -80
  88. sglang/srt/multimodal/processors/vila.py +1 -1
  89. sglang/srt/server_args.py +26 -4
  90. sglang/srt/two_batch_overlap.py +3 -0
  91. sglang/srt/utils.py +191 -48
  92. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  93. sglang/utils.py +5 -5
  94. sglang/version.py +1 -1
  95. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
  96. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
  97. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py

@@ -77,6 +77,7 @@ from sglang.srt.layers.quantization.int8_utils import (
 )
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -100,6 +101,7 @@ from sglang.srt.utils import (
     get_int_env_var,
     is_cpu,
     is_cuda,
+    is_flashinfer_available,
     is_hip,
     is_non_idle_and_non_empty,
     log_info_on_rank0,
@@ -132,6 +134,9 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )

+_is_flashinfer_available = is_flashinfer_available()
+_is_sm100_supported = is_cuda() and is_sm100_supported()
+

 logger = logging.getLogger(__name__)

@@ -195,13 +200,13 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x, forward_batch=None):
+    def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x

         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
         return x


@@ -210,8 +215,10 @@ class MoEGate(nn.Module):
         self,
         config,
         prefix: str = "",
+        is_nextn: bool = False,
     ):
         super().__init__()
+        self.is_nextn = is_nextn
         self.weight = nn.Parameter(
             torch.empty((config.n_routed_experts, config.hidden_size))
         )
@@ -233,8 +240,10 @@ class MoEGate(nn.Module):
                 True, # is_vnni
             )

+        # NOTE: For some unknown reason, router_gemm seems degrade accept length.
         if (
             _is_cuda
+            and not self.is_nextn
             and hidden_states.shape[0] < 4
             and hidden_states.shape[1] == 7168
             and self.weight.shape[0] == 256
@@ -258,6 +267,7 @@ class DeepseekV2MoE(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         alt_stream: Optional[torch.cuda.Stream] = None,
+        is_nextn: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -284,7 +294,9 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )

-        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
+        self.gate = MoEGate(
+            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
+        )

         self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts
@@ -402,7 +414,10 @@ class DeepseekV2MoE(nn.Module):
         ]

     def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        can_fuse_mlp_allreduce: bool = False,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -411,13 +426,17 @@ class DeepseekV2MoE(nn.Module):
                 and self.num_fused_shared_experts == 0
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
-                return self.forward_normal_dual_stream(hidden_states)
+                return self.forward_normal_dual_stream(
+                    hidden_states, can_fuse_mlp_allreduce
+                )
             else:
-                return self.forward_normal(hidden_states)
+                return self.forward_normal(hidden_states, can_fuse_mlp_allreduce)
         else:
             return self.forward_deepep(hidden_states, forward_batch)

-    def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward_normal_dual_stream(
+        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+    ) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)

@@ -433,11 +452,13 @@ class DeepseekV2MoE(nn.Module):
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
         final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward_normal(
+        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+    ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
         ):
@@ -454,7 +475,7 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

@@ -507,7 +528,7 @@ class DeepseekV2MoE(nn.Module):
                 None, # a2_scale
                 True, # is_vnni
             )
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not self.can_fuse_mlp_allreduce:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

@@ -1776,6 +1797,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
                 alt_stream=alt_stream,
+                is_nextn=is_nextn,
             )
         else:
             if enable_moe_dense_fully_dp():
@@ -1810,6 +1832,29 @@ class DeepseekV2DecoderLayer(nn.Module):
             and layer_id % self.config.moe_layer_freq == 0
         )

+    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
+        """Check if MLP allreduce can be fused with next layer's add_rmsnorm"""
+
+        if (
+            self.layer_id == self.config.num_hidden_layers - 1
+            or get_tensor_model_parallel_world_size() <= 1
+        ):
+            return False
+
+        if not global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False):
+            return False
+
+        if not _is_sm100_supported or not _is_flashinfer_available:
+            return False
+
+        if hasattr(forward_batch, "input_ids") and (
+            forward_batch.input_ids.shape[0] == 0
+            or forward_batch.input_ids.shape[0] > 128
+        ):
+            return False
+
+        return True
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -1834,12 +1879,22 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

-        hidden_states = self.mlp(hidden_states, forward_batch)
-
-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
+        can_fuse_mlp_allreduce = (
+            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
+            and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
+            and not self.is_nextn
         )

+        hidden_states = self.mlp(hidden_states, forward_batch, can_fuse_mlp_allreduce)
+
+        if can_fuse_mlp_allreduce:
+            hidden_states._sglang_needs_allreduce_fusion = True
+
+        if not can_fuse_mlp_allreduce:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
+
         return hidden_states, residual

     def op_comm_prepare_attn(
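
Note on the two hunks above: when _should_fuse_mlp_allreduce_with_next_layer returns True, the MLP's tensor-parallel all-reduce is skipped and the output tensor is tagged with _sglang_needs_allreduce_fusion, deferring the reduction to the next layer's fused allreduce + add_rmsnorm path (FlashInfer, SM100-class GPUs, small batches only). A minimal sketch of the deferral pattern, using plain torch.distributed in place of the fused kernel and an illustrative helper name (row_parallel_tail) that is not part of sglang:

    import torch
    import torch.distributed as dist


    def row_parallel_tail(partial_out: torch.Tensor, can_fuse_mlp_allreduce: bool) -> torch.Tensor:
        # After a row-parallel down_proj, every tensor-parallel rank holds only a partial sum.
        tp_active = dist.is_initialized() and dist.get_world_size() > 1
        if tp_active and not can_fuse_mlp_allreduce:
            # Eager path: one all-reduce per layer, as before this change.
            dist.all_reduce(partial_out)
        elif tp_active:
            # Deferred path: tag the tensor (attribute name taken from the diff) so the
            # next layer's communicator performs a single fused allreduce + add_rmsnorm.
            partial_out._sglang_needs_allreduce_fusion = True
        return partial_out
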
@@ -1930,7 +1985,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            use_attn_tp_group=True,
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
@@ -2138,7 +2193,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             # This may affect the accuracy of fp8 model.
             # Fix deepseek v3 blockwise bmm by using deep_gemm
             use_deep_gemm_bmm = False
-            model_dtype = torch.get_default_dtype()

             if w.dtype in (
                 torch.float8_e4m3fn,
@@ -2164,7 +2218,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                     _is_cuda
                     and weight_block_size[0] == 128
                     and weight_block_size[1] == 128
-                    and model_dtype == torch.bfloat16
                 ):
                     if (
                         deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
@@ -2178,7 +2231,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                             weight,
                             weight_scale,
                             weight_block_size,
-                            model_dtype,
+                            torch.bfloat16,
                         )
                     else:
                         w, scale = block_quant_to_tensor_quant(
@@ -2355,6 +2408,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
+        if self.quant_config and self.quant_config.get_name() == "w4afp8":
+            expert_params_mapping += (
+                get_moe_impl_class().make_expert_input_scale_params_mapping(
+                    num_experts=self.config.n_routed_experts
+                )
+            )

         # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
         fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
sglang/srt/models/deepseek_vl2.py

@@ -227,7 +227,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
             input_ids=input_ids,
             positions=positions,
             forward_batch=forward_batch,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             language_model=self.language_model,
         )

sglang/srt/models/gemma3_mm.py

@@ -374,7 +374,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
             input_ids=llm_input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             positions=positions,
         )

sglang/srt/models/gemma3n_mm.py

@@ -1,7 +1,7 @@
 import logging
 import re
 from functools import lru_cache
-from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict, Union
+from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union

 import torch
 from torch import nn
@@ -25,6 +25,7 @@ from sglang.srt.managers.mm_utils import (
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import (
+    Modality,
     MultimodalDataItem,
     MultimodalInputs,
     flatten_nested_list,
@@ -434,8 +435,10 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
-            audio_data_embedding_func=self.get_audio_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+                Modality.AUDIO: self.get_audio_feature,
+            },
             positions=positions,
             per_layer_inputs=per_layer_inputs,
         )
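
Note on the recurring multimodal change in this release: general_mm_embed_routine no longer takes per-modality keywords such as image_data_embedding_func / audio_data_embedding_func. Models now pass either multimodal_model=self (deepseek_vl2, gemma3_mm, minicpm hunks) or a data_embedding_funcs dict keyed by Modality (this hunk and the internvl, kimi_vl, llava hunks below). A hedged sketch of the dict form, with an illustrative wrapper function (run_mm_forward) that is not part of sglang; the routine's keyword names and the Modality keys are taken from the diff above:

    from sglang.srt.managers.mm_utils import general_mm_embed_routine
    from sglang.srt.managers.schedule_batch import Modality


    def run_mm_forward(model, input_ids, positions, forward_batch):
        # `model` stands for any sglang multimodal model that exposes get_image_feature /
        # get_audio_feature and a .language_model; the wrapper itself is illustrative only.
        return general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=model.language_model,
            data_embedding_funcs={
                Modality.IMAGE: model.get_image_feature,
                Modality.AUDIO: model.get_audio_feature,
            },
            positions=positions,
        )
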
sglang/srt/models/internvl.py

@@ -29,7 +29,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_janus_pro import DropPath
@@ -523,7 +527,9 @@ class InternVLChatModel(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             positions=positions,
         )

sglang/srt/models/kimi_vl.py

@@ -67,7 +67,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -168,7 +172,9 @@ class KimiVLForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             positions=positions,
         )

sglang/srt/models/llama.py

@@ -575,6 +575,8 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
sglang/srt/models/llava.py

@@ -787,7 +787,9 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
             forward_batch=forward_batch,
             get_embedding=get_embedding,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             placeholder_tokens=None, # using mm_item.pad_value
             positions=positions,
         )
sglang/srt/models/llavavid.py

@@ -142,7 +142,7 @@ class LlavaVidForCausalLM(nn.Module):
         )
         image_offsets = [
             flatten_nested_list(
-                [item.image_offsets for item in image_inputs[i].mm_items]
+                [item.offsets for item in image_inputs[i].mm_items]
             )
             for i in range(bs)
             if need_vision[i]
sglang/srt/models/minicpmo.py

@@ -1827,8 +1827,7 @@ class MiniCPMO(MiniCPMBaseModel):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.llm,
-            image_data_embedding_func=self.get_image_feature,
-            audio_data_embedding_func=self.get_audio_feature,
+            multimodal_model=self,
             positions=positions,
         )
         return hidden_states
sglang/srt/models/minicpmv.py

@@ -573,7 +573,7 @@ class MiniCPMBaseModel(nn.Module):
         hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             language_model=self.llm,
             positions=positions,
         )
sglang/srt/models/mixtral_quant.py

@@ -407,6 +407,8 @@ class QuantMixtralForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -418,6 +420,8 @@ class QuantMixtralForCausalLM(nn.Module):
                 # Skip experts that are not assigned to this worker.
                 if "block_sparse_moe.experts." in name and name not in params_dict:
                     continue
+                if name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)