sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +49 -7
  2. sglang/lang/chat_template.py +24 -0
  3. sglang/srt/_custom_ops.py +59 -92
  4. sglang/srt/configs/model_config.py +5 -0
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/conversation.py +29 -4
  7. sglang/srt/custom_op.py +5 -0
  8. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  9. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/layers/attention/flashattention_backend.py +678 -83
  12. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  13. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  14. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  16. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  17. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  18. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
  30. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  31. sglang/srt/layers/moe/topk.py +49 -3
  32. sglang/srt/layers/quantization/__init__.py +5 -1
  33. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  35. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  36. sglang/srt/layers/quantization/fp8.py +3 -1
  37. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  38. sglang/srt/layers/quantization/moe_wna16.py +503 -0
  39. sglang/srt/layers/quantization/utils.py +1 -1
  40. sglang/srt/layers/quantization/w8a8_int8.py +2 -0
  41. sglang/srt/layers/radix_attention.py +2 -0
  42. sglang/srt/layers/rotary_embedding.py +63 -12
  43. sglang/srt/managers/cache_controller.py +34 -11
  44. sglang/srt/managers/mm_utils.py +202 -156
  45. sglang/srt/managers/multimodal_processor.py +0 -2
  46. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  47. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  48. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  49. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  50. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  51. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  52. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  53. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  54. sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
  55. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  56. sglang/srt/managers/schedule_batch.py +185 -128
  57. sglang/srt/managers/scheduler.py +4 -4
  58. sglang/srt/managers/tokenizer_manager.py +1 -1
  59. sglang/srt/managers/utils.py +1 -6
  60. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  61. sglang/srt/mem_cache/memory_pool.py +72 -6
  62. sglang/srt/mem_cache/paged_allocator.py +39 -0
  63. sglang/srt/metrics/collector.py +23 -53
  64. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  65. sglang/srt/model_executor/forward_batch_info.py +10 -10
  66. sglang/srt/model_executor/model_runner.py +60 -57
  67. sglang/srt/model_loader/loader.py +8 -0
  68. sglang/srt/models/clip.py +12 -7
  69. sglang/srt/models/deepseek_janus_pro.py +10 -15
  70. sglang/srt/models/deepseek_v2.py +212 -121
  71. sglang/srt/models/deepseek_vl2.py +105 -104
  72. sglang/srt/models/gemma3_mm.py +14 -80
  73. sglang/srt/models/llama.py +16 -5
  74. sglang/srt/models/llama4.py +420 -0
  75. sglang/srt/models/llava.py +31 -19
  76. sglang/srt/models/llavavid.py +16 -7
  77. sglang/srt/models/minicpmo.py +63 -147
  78. sglang/srt/models/minicpmv.py +17 -27
  79. sglang/srt/models/mllama.py +29 -14
  80. sglang/srt/models/mllama4.py +154 -0
  81. sglang/srt/models/qwen2.py +9 -6
  82. sglang/srt/models/qwen2_5_vl.py +21 -31
  83. sglang/srt/models/qwen2_vl.py +20 -21
  84. sglang/srt/openai_api/adapter.py +18 -6
  85. sglang/srt/platforms/interface.py +371 -0
  86. sglang/srt/server_args.py +99 -14
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  88. sglang/srt/speculative/eagle_utils.py +140 -28
  89. sglang/srt/speculative/eagle_worker.py +93 -24
  90. sglang/srt/utils.py +104 -51
  91. sglang/test/test_custom_ops.py +55 -0
  92. sglang/test/test_utils.py +13 -26
  93. sglang/utils.py +2 -2
  94. sglang/version.py +1 -1
  95. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
  96. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
  97. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
sglang/srt/models/minicpmo.py
@@ -40,16 +40,19 @@ from transformers.models.whisper.modeling_whisper import (
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
-    embed_mm_inputs,
-    get_multimodal_data_bounds,
+    general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
+    MultimodalInputs,
+    flatten_nested_list,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.minicpmv import (
     Idefics2VisionTransformer,
-    MiniCPMVBaseModel,
+    MiniCPMBaseModel,
     Resampler2_5,
 )
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -1409,7 +1412,7 @@ class MultiModalProjector(nn.Module):
         return hidden_states


-class MiniCPMO(MiniCPMVBaseModel):
+class MiniCPMO(MiniCPMBaseModel):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -1537,7 +1540,7 @@ class MiniCPMO(MiniCPMVBaseModel):

         return input_lengths_after_cnn, input_lengths_after_pooling

-    def get_audio_embedding_streaming(self, multimodal_input: MultimodalInputs):
+    def get_audio_embedding_streaming(self, items: List[MultimodalDataItem]):
         r"""
         Extract audio embeddings in a streaming manner using cached key-value pairs.

@@ -1545,26 +1548,15 @@ class MiniCPMO(MiniCPMVBaseModel):
         for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended
         for streaming scenarios.

-        Args:
-            multimodal_input (dict):
-                - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
-                - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
-
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
-        # print("audio embedding")
-
-        wavforms = (
-            []
-            if multimodal_input.audio_features is None
-            else multimodal_input.audio_features
+        wavforms = flatten_nested_list(
+            [item.audio_features for item in items if item.audio_features]
         )
         # list, [[x1, x2], [y1], [z1]]
-        audio_feature_lens_raw = (
-            []
-            if multimodal_input.audio_feature_lens is None
-            else multimodal_input.audio_feature_lens
+        audio_feature_lens_raw = flatten_nested_list(
+            [item.audio_feature_lens for item in items if item.audio_feature_lens]
         )

         # exist audio
@@ -1650,7 +1642,7 @@ class MiniCPMO(MiniCPMVBaseModel):
             ret[i, start:ending] = True
         return ret

-    def get_audio_embedding(self, multimodal_input: MultimodalInputs, chunk_length=-1):
+    def get_audio_embedding(self, items: List[MultimodalDataItem], chunk_length=-1):
         r"""
         Extract full audio embeddings with optional chunk-based attention.

@@ -1659,31 +1651,25 @@ class MiniCPMO(MiniCPMVBaseModel):
         not use key-value caching and is suitable for non-streaming inference.

         Args:
-            multimodal_input (dict):
-                - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
-                - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
             chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based
                 attention (>0) during embedding computation.

         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
-        # print("audio embedding")
         # (bs, 80, frames) or [], multi audios need filled in advance
-        wavforms = (
-            []
-            if multimodal_input.audio_features is None
-            else multimodal_input.audio_features
+        wavforms = flatten_nested_list(
+            [item.audio_features for item in items if item.audio_features]
         )
         # list, [[x1, x2], [y1], [z1]]
-        audio_feature_lens_raw = (
-            []
-            if multimodal_input.audio_feature_lens is None
-            else multimodal_input.audio_feature_lens
+        audio_feature_lens_raw = flatten_nested_list(
+            [item.audio_feature_lens for item in items if item.audio_feature_lens]
         )

         final_audio_embeds = []

+        assert isinstance(wavforms, list)
+        assert isinstance(wavforms[0], torch.Tensor)
         # exist audio
         for wavform in wavforms:
             if len(wavform) > 0:
@@ -1757,86 +1743,46 @@ class MiniCPMO(MiniCPMVBaseModel):
             final_audio_embeds.append(target_audio_embeds)
         return final_audio_embeds

+    def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        embedding = self.get_omni_embedding(
+            items=items,
+            chunk_length=self.config.audio_chunk_length,
+            stream_input=False,
+        )
+        return embedding
+
     def get_omni_embedding(
         self,
-        input_ids,
-        multimodal_input: MultimodalInputs,
-        input_embeds: torch.Tensor,
-        forward_mode: ForwardMode,
+        items: List[MultimodalDataItem],
         chunk_length=-1,
         stream_input=False,
     ):
         """
         Args:
-            multimodal_input:
-            input_embeds:
             chunk_length: whisper use full attention or chunk attention
             stream_input: use streaming audio embedding
         Returns:
             final embeddings with audio feature
         """
-        input_embeds = input_embeds.unsqueeze(0)
-        if not forward_mode.is_decode() and multimodal_input.contains_audio_inputs():
-            audio_bounds = get_multimodal_data_bounds(
-                input_ids=input_ids,
-                pad_values=multimodal_input.pad_values,
-                token_pairs=[
-                    (multimodal_input.audio_start_id, multimodal_input.audio_end_id)
-                ],
-            )
-            if audio_bounds.numel() == 0:
-                input_embeds = input_embeds.squeeze(0)
-                # TODO
-                logger.warn("Unimplemented logic. Please try disabling chunked prefill")
-                return input_embeds
-            audio_bounds = audio_bounds.unsqueeze(0)
-            bs = len(input_embeds)
-
-            if stream_input:
-                audio_embeddings = self.get_audio_embedding_streaming(multimodal_input)
-            else:
-                audio_embeddings = self.get_audio_embedding(
-                    multimodal_input, chunk_length
-                )
-            # batch size
-            assert len(audio_embeddings) == len(input_embeds)
-            if len(audio_embeddings) > 0:
-                if self.config.chunk_input:
-                    for i in range(bs):
-                        audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
-                            device=input_embeds.device, dtype=input_embeds.dtype
-                        )
-                        audio_start_pos = 0
-                        for bound in audio_bounds[i]:
-                            audio_len = bound[1] - bound[0] + 1
-                            input_embeds[0, bound[0] : bound[1] + 1] = audio_embs[
-                                audio_start_pos : audio_start_pos + audio_len, :
-                            ]
-                            audio_start_pos += audio_len
-                else:
-                    for i in range(bs):
-                        audio_embs = audio_embeddings[i]
-                        bounds = audio_bounds[i]
-                        for embs, bound in zip(audio_embs, bounds):
-                            audio_indices = torch.arange(
-                                bound[0], bound[1], dtype=torch.long
-                            ).to(input_embeds.device)
-
-                            if embs.shape[0] != len(audio_indices):
-                                raise ValueError(
-                                    f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
-                                    f"to input indices of length {len(audio_indices)}"
-                                )
-                            input_embeds[i, audio_indices] = embs.to(input_embeds.dtype)
-        input_embeds = input_embeds.squeeze(0)
-        return input_embeds
-
-    def get_image_features(
-        self,
-        image_inputs: MultimodalInputs,
-    ) -> torch.Tensor:
-        pixel_values = image_inputs.pixel_values
-        tgt_sizes = image_inputs.tgt_sizes
+
+        if stream_input:
+            audio_embeddings = self.get_audio_embedding_streaming(items)
+        else:
+            audio_embeddings = self.get_audio_embedding(items, chunk_length)
+        bs = len(audio_embeddings)
+        # batch size
+        audio_embs = torch.cat(flatten_nested_list(audio_embeddings), dim=0)
+
+        return audio_embs
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # list of tensors
+        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        tgt_sizes = torch.stack(
+            flatten_nested_list([item.tgt_size for item in items]), dim=0
+        )
+        assert len(pixel_values) == tgt_sizes.shape[0]
+
         device = self.vpm.embeddings.position_embedding.weight.device
         dtype = self.vpm.embeddings.position_embedding.weight.dtype
         all_pixel_values_lst = [
@@ -1845,10 +1791,10 @@ class MiniCPMO(MiniCPMVBaseModel):

         max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
         assert isinstance(max_patches, int)
-
         all_pixel_values = torch.nn.utils.rnn.pad_sequence(
             all_pixel_values_lst, batch_first=True, padding_value=0.0
         )
+
         B, L, _ = all_pixel_values.shape
         all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
         patch_attn_mask = torch.zeros(
@@ -1875,53 +1821,23 @@ class MiniCPMO(MiniCPMVBaseModel):
         forward_batch: ForwardBatch,
         **kwargs: Any,
     ) -> torch.Tensor:
-        inputs_embeds = None
-        # TODO(mick): optimize the logic here: clamp, merge and embedding should happens at most once
-        if (
-            not forward_batch.forward_mode.is_decode()
-            and forward_batch.contains_image_inputs()
-        ):
-            mm_inputs = forward_batch.merge_mm_inputs()
-            inputs_embeds = embed_mm_inputs(
-                mm_input=mm_inputs,
-                input_ids=input_ids,
-                input_embedding=self.get_input_embeddings(),
-                mm_data_embedding_func=self.get_image_features,
-                placeholder_token_ids=[mm_inputs.im_token_id] + mm_inputs.pad_values,
-            )

-        input_ids = input_ids.clamp(
-            min=0, max=self.get_input_embeddings().num_embeddings - 1
+        mm_input = forward_batch.merge_mm_inputs()
+        placeholder_token_ids = (
+            ([mm_input.im_token_id] + [item.pad_value for item in mm_input.mm_items])
+            if forward_batch.contains_mm_inputs()
+            else []
         )
-        if inputs_embeds is None:
-            inputs_embeds = self.llm.get_input_embeddings(input_ids)
-        if (
-            not forward_batch.forward_mode.is_decode()
-            and self.config.init_audio
-            and forward_batch.contains_audio_inputs()
-        ):
-            mm_input = forward_batch.merge_mm_inputs()
-            inputs_embeds = self.get_omni_embedding(
-                input_ids=input_ids,
-                multimodal_input=mm_input,
-                input_embeds=inputs_embeds,
-                forward_mode=forward_batch.forward_mode,
-                chunk_length=self.config.audio_chunk_length,
-                stream_input=False,
-            )
-
-        forward_batch.mm_inputs = None
-
-        hidden_states = self.llm.model(
-            input_ids=None,
-            positions=positions,
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
             forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
-        )
-
-        return self.logits_processor(
-            input_ids, hidden_states, self.llm.lm_head, forward_batch
+            language_model=self.llm,
+            image_data_embedding_func=self.get_image_feature,
+            audio_data_embedding_func=self.get_audio_feature,
+            placeholder_token_ids=placeholder_token_ids,
+            positions=positions,
         )
+        return hidden_states

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
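
The minicpmo.py hunks above repeatedly pass per-item feature lists through flatten_nested_list before concatenation. As a rough, hedged sketch of the flattening behavior those call sites rely on (illustration only; the helper actually shipped in the wheel may differ in detail):

from typing import Any, List


def flatten_nested_list(nested: List[Any]) -> List[Any]:
    # Recursively collapse arbitrarily nested lists into one flat list,
    # which is what the call sites above assume before torch.cat / torch.stack.
    flat: List[Any] = []
    for elem in nested:
        if isinstance(elem, list):
            flat.extend(flatten_nested_list(elem))
        else:
            flat.append(elem)
    return flat


# Per-item audio features arrive as nested lists, e.g. [[a, b], [c]] -> [a, b, c].
print(flatten_nested_list([[1, 2], [3], [[4, 5]]]))  # [1, 2, 3, 4, 5]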
sglang/srt/models/minicpmv.py
@@ -54,12 +54,12 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list

 RawImageType = Union[Image.Image, torch.Tensor]

@@ -661,7 +661,7 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
     return tuple(int(x) for x in version_str.split("."))


-class MiniCPMVBaseModel(nn.Module):
+class MiniCPMBaseModel(nn.Module):
     """
     The abstract class of MiniCPMV can only be inherited, but cannot be
     instantiated.
@@ -853,7 +853,7 @@ class MiniCPMVBaseModel(nn.Module):
         return vlm_embedding, vision_hidden_states

     def get_input_embeddings(self) -> nn.Embedding:
-        return self.llm.get_input_embedding()
+        return self.llm.get_input_embeddings()

     def forward(
         self,
@@ -862,23 +862,14 @@ class MiniCPMVBaseModel(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: Any,
     ) -> torch.Tensor:
-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_features,
-        )
-
-        hidden_states = self.llm.model(
-            input_ids=None,
+            image_data_embedding_func=self.get_image_feature,
+            language_model=self.llm,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
-        )
-
-        return self.logits_processor(
-            input_ids, hidden_states, self.llm.lm_head, forward_batch
         )
+        return hidden_states

     def init_llm(
         self,
@@ -913,11 +904,11 @@ class MiniCPMVBaseModel(nn.Module):
     ) -> torch.Tensor:
         raise NotImplementedError

-    def get_image_features(self, image_inputs: MultimodalInputs) -> torch.Tensor:
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         raise NotImplementedError


-class MiniCPMV2_6(MiniCPMVBaseModel):
+class MiniCPMV2_6(MiniCPMBaseModel):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -1023,14 +1014,13 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         )
         return vision_embedding

-    def get_image_features(
-        self,
-        image_inputs: MultimodalInputs,
-    ) -> torch.Tensor:
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # list of tensors
-        pixel_values = image_inputs.pixel_values
-
-        tgt_sizes = image_inputs.tgt_sizes
+        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        tgt_sizes = torch.stack(
+            flatten_nested_list([item.tgt_size for item in items]), dim=0
+        )
+        assert len(pixel_values) == tgt_sizes.shape[0]

         device = self.vpm.embeddings.position_embedding.weight.device
         dtype = self.vpm.embeddings.position_embedding.weight.dtype
@@ -1040,10 +1030,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel):

         max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
         assert isinstance(max_patches, int)
-
         all_pixel_values = torch.nn.utils.rnn.pad_sequence(
             all_pixel_values_lst, batch_first=True, padding_value=0.0
         )
+
         B, L, _ = all_pixel_values.shape
         all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
         patch_attn_mask = torch.zeros(
sglang/srt/models/mllama.py
@@ -796,14 +796,16 @@ class MllamaForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False

-    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        pixel_values = image_inputs.pixel_values
-        pad_values = image_inputs.pad_values
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pixel_values = torch.cat(
+            [item.pixel_values for item in mm_inputs.mm_items], dim=0
+        )
+        pad_values = [item.pad_value for item in mm_inputs.mm_items]

         num_concurrent_media, num_tiles = pixel_values.shape[1:3]
         num_patches = self.vision_model.num_patches
         image_len = num_concurrent_media * num_tiles * num_patches
-        image_inputs.num_image_tokens = image_len
+        mm_inputs.num_image_tokens = image_len

         pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))

@@ -815,10 +817,16 @@ class MllamaForConditionalGeneration(nn.Module):

         # pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
         max_num_images = max_num_tiles = bs = 0
-        for i, im in enumerate(forward_batch.mm_inputs):
-            if not forward_batch.encoder_cached[i] and im is not None:
-                max_num_images = max(max_num_images, im.pixel_values.shape[1])
-                max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2])
+        for i, mm_input in enumerate(forward_batch.mm_inputs):
+
+            if not forward_batch.encoder_cached[i] and mm_input is not None:
+                pixel_values = torch.cat(
+                    [item.pixel_values for item in mm_input.mm_items], dim=0
+                )
+                # max_num_images = max(max_num_images, sum(1 if item.is_image() else 0 for item in mm_input.items))
+                max_num_images = max(max_num_images, pixel_values.shape[1])
+
+                max_num_tiles = max(max_num_tiles, pixel_values.shape[2])
                 bs += 1

         if max_num_images * max_num_tiles * bs == 0:
@@ -842,17 +850,24 @@ class MllamaForConditionalGeneration(nn.Module):
         )
         i = 0
         encoder_lens_need = []
-        for k, im in enumerate(forward_batch.mm_inputs):
-            if forward_batch.encoder_cached[k] or im is None:
+
+        for k, mm_input in enumerate(forward_batch.mm_inputs):
+            if forward_batch.encoder_cached[k] or mm_input is None:
                 continue

             encoder_lens_need.append(forward_batch.encoder_lens[k])
-            for j in range(im.pixel_values.shape[1]):
-                img = im.pixel_values[0, j]
+            pixel_values = torch.cat(
+                [item.pixel_values for item in mm_input.mm_items], dim=0
+            )
+            for j in range(pixel_values.shape[1]):
+                img = pixel_values[0, j]
                 num_tiles = img.shape[0]
                 batched_images[i, j, :num_tiles] = img
-                batched_ar_ids[i, j] = im.aspect_ratio_ids[0, j]
-                batched_ar_mask[i, j, :num_tiles] = im.aspect_ratio_mask[0, j]
+                batched_ar_ids[i, j] = mm_input.mm_items[0].aspect_ratio_id[0, j]
+
+                batched_ar_mask[i, j, :num_tiles] = mm_input.mm_items[
+                    0
+                ].aspect_ratio_mask[0, j]
             i += 1

         return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need
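
The pad_input_ids hunk above keeps the pre-existing padding arithmetic for the image region. A small self-contained check of that expression, with made-up values purely for illustration:

# Illustration of the pad_ids expression used in pad_input_ids above:
# repeating the per-item pad values until at least image_len ids are available.
pad_values = [101, 102, 103]  # hypothetical per-item pad values
image_len = 7                 # hypothetical number of image tokens

pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))

assert len(pad_ids) >= image_len
print(pad_ids[:image_len])  # [101, 102, 103, 101, 102, 103, 101]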
sglang/srt/models/mllama4.py (new file)
@@ -0,0 +1,154 @@
+# TODO: add Aapted from vllm/mllama4.py
+from collections.abc import Iterable
+from typing import Optional, Set, Tuple
+
+import torch
+from torch import nn
+from transformers import Llama4Config
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
+
+
+class Llama4ForConditionalGeneration(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+    }
+
+    def __init__(
+        self,
+        config: Llama4Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+
+        # Initialize the language model
+        from sglang.srt.models.llama4 import Llama4ForCausalLM
+
+        self.language_model = Llama4ForCausalLM(
+            config.text_config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
+
+        self.logits_processor = LogitsProcessor(config.text_config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        **kwargs: object,
+    ) -> torch.Tensor:
+
+        return self.language_model(input_ids, positions, forward_batch)
+
+    def permute_qk_weight_for_rotary(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+    ) -> Tuple[str, torch.Tensor]:
+
+        def permute(w: torch.Tensor, n_heads: int):
+            attn_in = self.language_model.config.head_dim * n_heads
+            attn_out = self.language_model.config.hidden_size
+
+            return (
+                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
+                .transpose(1, 2)
+                .reshape(attn_in, attn_out)
+            )
+
+        modules = name.split(".")
+
+        # rotary embeds should be sliced
+        if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight":
+            loaded_weight = permute(
+                loaded_weight, self.language_model.config.num_key_value_heads
+            )
+        elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight":
+            loaded_weight = permute(
+                loaded_weight, self.language_model.config.num_attention_heads
+            )
+
+        return name, loaded_weight
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
+            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
+            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
+            (".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
+            (".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
+            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
+            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+
+        num_experts = self.config.text_config.num_local_experts
+
+        for name, loaded_weight in weights:
+
+            if name.startswith("vision_model") or name.startswith(
+                "multi_modal_projector"
+            ):
+                continue
+
+            name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if ".experts" in name:
+                    if ".gate_up_proj" in name:
+                        name_list = [
+                            name.replace(".experts.gate_up_proj", ".experts.w13_weight")
+                        ] * 2
+                        loaded_weight_list = loaded_weight.chunk(2, dim=-1)
+                        shard_id_list = ["w1", "w3"]
+                    else:
+                        name_list = [
+                            name.replace(".experts.down_proj", ".experts.w2_weight")
+                        ]
+                        shard_id_list = ["w2"]
+                        loaded_weight_list = [loaded_weight]
+                    for name, loaded_weight, shard_id in zip(
+                        name_list, loaded_weight_list, shard_id_list
+                    ):
+                        param = params_dict[name]
+                        weight_loader = param.weight_loader
+                        for expert_id in range(num_experts):
+                            weight_loader(
+                                param,
+                                loaded_weight[expert_id].T,
+                                name,
+                                shard_id=shard_id,
+                                expert_id=expert_id,
+                            )
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = Llama4ForConditionalGeneration
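
The permute helper inside permute_qk_weight_for_rotary reorders each head's interleaved rotary pairs into two contiguous half-blocks of rows. A self-contained sketch of that reshape on toy dimensions (all sizes below are arbitrary and chosen only for illustration):

import torch

n_heads, head_dim, hidden_size = 2, 4, 8
attn_in = n_heads * head_dim  # rows of the q/k projection weight

w = torch.arange(attn_in * hidden_size, dtype=torch.float32).reshape(attn_in, hidden_size)


def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    # Same reshape pattern as in the new file: split each head's rows into
    # (pairs, 2), swap those two axes, then flatten back to (attn_in, hidden).
    return (
        w.view(n_heads, attn_in // n_heads // 2, 2, hidden_size)
        .transpose(1, 2)
        .reshape(attn_in, hidden_size)
    )


print(permute(w, n_heads).shape)  # torch.Size([8, 8]); rows 0..3 come out as 0, 2, 1, 3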
sglang/srt/models/qwen2.py
@@ -261,11 +261,14 @@ class Qwen2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
-            return self.embed_tokens(input_ids) * self.config.scale_emb
+            return self.get_input_embeddings()(input_ids) * self.config.scale_emb
         else:
-            return self.embed_tokens(input_ids)
+            return self.get_input_embeddings()(input_ids)
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.embed_tokens

     def forward(
         self,
@@ -358,10 +361,10 @@ class Qwen2ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embedding(input_ids)

-    def get_input_embedding(self) -> nn.Embedding:
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens

     @torch.no_grad()
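
The qwen2.py hunks swap the two accessor names: get_input_embeddings() now returns the embedding module (the convention the multimodal code above expects), while get_input_embedding(input_ids) applies it, including the optional scale_emb factor. A minimal toy sketch of that pair, using a stand-in class rather than the real Qwen2Model:

import torch
from torch import nn


class TinyTextModel(nn.Module):
    # Toy stand-in, only to illustrate the renamed accessor pair.
    def __init__(self, vocab_size: int = 16, hidden_size: int = 4, scale_emb: float = 2.0):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.scale_emb = scale_emb

    def get_input_embeddings(self) -> nn.Embedding:
        # returns the module, mirroring the new Qwen2Model.get_input_embeddings
        return self.embed_tokens

    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
        # applies the module with scaling, mirroring Qwen2Model.get_input_embedding
        return self.get_input_embeddings()(input_ids) * self.scale_emb


model = TinyTextModel()
print(model.get_input_embedding(torch.tensor([1, 2, 3])).shape)  # torch.Size([3, 4])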