sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +49 -7
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/layers/attention/flashattention_backend.py +394 -76
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +49 -3
- sglang/srt/layers/quantization/__init__.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/rotary_embedding.py +0 -12
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +7 -26
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -128
- sglang/srt/managers/scheduler.py +4 -4
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +8 -6
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +59 -57
- sglang/srt/model_loader/loader.py +8 -0
- sglang/srt/models/clip.py +12 -7
- sglang/srt/models/deepseek_janus_pro.py +10 -15
- sglang/srt/models/deepseek_v2.py +212 -121
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_mm.py +14 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +18 -6
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +99 -14
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +93 -24
- sglang/srt/utils.py +104 -51
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +13 -26
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/minicpmo.py
CHANGED
@@ -40,16 +40,19 @@ from transformers.models.whisper.modeling_whisper import (
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
-    embed_mm_inputs,
-    get_multimodal_data_bounds,
+    general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
+    MultimodalInputs,
+    flatten_nested_list,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.minicpmv import (
     Idefics2VisionTransformer,
-    MiniCPMVBaseModel,
+    MiniCPMBaseModel,
     Resampler2_5,
 )
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -1409,7 +1412,7 @@ class MultiModalProjector(nn.Module):
         return hidden_states


-class MiniCPMO(MiniCPMVBaseModel):
+class MiniCPMO(MiniCPMBaseModel):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -1537,7 +1540,7 @@ class MiniCPMO(MiniCPMVBaseModel):

         return input_lengths_after_cnn, input_lengths_after_pooling

-    def get_audio_embedding_streaming(self, multimodal_input: MultimodalInputs):
+    def get_audio_embedding_streaming(self, items: List[MultimodalDataItem]):
         r"""
         Extract audio embeddings in a streaming manner using cached key-value pairs.

@@ -1545,26 +1548,15 @@ class MiniCPMO(MiniCPMVBaseModel):
         for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended
         for streaming scenarios.

-        Args:
-            multimodal_input (dict):
-                - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
-                - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
-
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
-
-
-        wavforms = (
-            []
-            if multimodal_input.audio_features is None
-            else multimodal_input.audio_features
+        wavforms = flatten_nested_list(
+            [item.audio_features for item in items if item.audio_features]
         )
         # list, [[x1, x2], [y1], [z1]]
-        audio_feature_lens_raw = (
-            []
-            if multimodal_input.audio_feature_lens is None
-            else multimodal_input.audio_feature_lens
+        audio_feature_lens_raw = flatten_nested_list(
+            [item.audio_feature_lens for item in items if item.audio_feature_lens]
         )

         # exist audio
@@ -1650,7 +1642,7 @@ class MiniCPMO(MiniCPMVBaseModel):
                 ret[i, start:ending] = True
         return ret

-    def get_audio_embedding(self, multimodal_input: MultimodalInputs, chunk_length=-1):
+    def get_audio_embedding(self, items: List[MultimodalDataItem], chunk_length=-1):
         r"""
         Extract full audio embeddings with optional chunk-based attention.

@@ -1659,31 +1651,25 @@ class MiniCPMO(MiniCPMVBaseModel):
         not use key-value caching and is suitable for non-streaming inference.

         Args:
-            multimodal_input (dict):
-                - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
-                - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
             chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based
                 attention (>0) during embedding computation.

         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
-        # print("audio embedding")
         # (bs, 80, frames) or [], multi audios need filled in advance
-        wavforms = (
-            []
-            if multimodal_input.audio_features is None
-            else multimodal_input.audio_features
+        wavforms = flatten_nested_list(
+            [item.audio_features for item in items if item.audio_features]
         )
         # list, [[x1, x2], [y1], [z1]]
-        audio_feature_lens_raw = (
-            []
-            if multimodal_input.audio_feature_lens is None
-            else multimodal_input.audio_feature_lens
+        audio_feature_lens_raw = flatten_nested_list(
+            [item.audio_feature_lens for item in items if item.audio_feature_lens]
         )

         final_audio_embeds = []

+        assert isinstance(wavforms, list)
+        assert isinstance(wavforms[0], torch.Tensor)
         # exist audio
         for wavform in wavforms:
             if len(wavform) > 0:
@@ -1757,86 +1743,46 @@ class MiniCPMO(MiniCPMVBaseModel):
             final_audio_embeds.append(target_audio_embeds)
         return final_audio_embeds

+    def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        embedding = self.get_omni_embedding(
+            items=items,
+            chunk_length=self.config.audio_chunk_length,
+            stream_input=False,
+        )
+        return embedding
+
     def get_omni_embedding(
         self,
-
-        multimodal_input: MultimodalInputs,
-        input_embeds: torch.Tensor,
-        forward_mode: ForwardMode,
+        items: List[MultimodalDataItem],
         chunk_length=-1,
         stream_input=False,
     ):
         """
         Args:
-            multimodal_input:
-            input_embeds:
             chunk_length: whisper use full attention or chunk attention
             stream_input: use streaming audio embedding
         Returns:
             final embeddings with audio feature
         """
-
-        if
-            ...
-        else:
-            audio_embeddings = self.get_audio_embedding(
-                multimodal_input, chunk_length
-            )
-        # batch size
-        assert len(audio_embeddings) == len(input_embeds)
-        if len(audio_embeddings) > 0:
-            if self.config.chunk_input:
-                for i in range(bs):
-                    audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
-                        device=input_embeds.device, dtype=input_embeds.dtype
-                    )
-                    audio_start_pos = 0
-                    for bound in audio_bounds[i]:
-                        audio_len = bound[1] - bound[0] + 1
-                        input_embeds[0, bound[0] : bound[1] + 1] = audio_embs[
-                            audio_start_pos : audio_start_pos + audio_len, :
-                        ]
-                        audio_start_pos += audio_len
-            else:
-                for i in range(bs):
-                    audio_embs = audio_embeddings[i]
-                    bounds = audio_bounds[i]
-                    for embs, bound in zip(audio_embs, bounds):
-                        audio_indices = torch.arange(
-                            bound[0], bound[1], dtype=torch.long
-                        ).to(input_embeds.device)
-
-                        if embs.shape[0] != len(audio_indices):
-                            raise ValueError(
-                                f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
-                                f"to input indices of length {len(audio_indices)}"
-                            )
-                        input_embeds[i, audio_indices] = embs.to(input_embeds.dtype)
-        input_embeds = input_embeds.squeeze(0)
-        return input_embeds
-
-    def get_image_features(
-        self,
-        image_inputs: MultimodalInputs,
-    ) -> torch.Tensor:
-        pixel_values = image_inputs.pixel_values
-        tgt_sizes = image_inputs.tgt_sizes
+
+        if stream_input:
+            audio_embeddings = self.get_audio_embedding_streaming(items)
+        else:
+            audio_embeddings = self.get_audio_embedding(items, chunk_length)
+        bs = len(audio_embeddings)
+        # batch size
+        audio_embs = torch.cat(flatten_nested_list(audio_embeddings), dim=0)
+
+        return audio_embs
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # list of tensors
+        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        tgt_sizes = torch.stack(
+            flatten_nested_list([item.tgt_size for item in items]), dim=0
+        )
+        assert len(pixel_values) == tgt_sizes.shape[0]
+
         device = self.vpm.embeddings.position_embedding.weight.device
         dtype = self.vpm.embeddings.position_embedding.weight.dtype
         all_pixel_values_lst = [
@@ -1845,10 +1791,10 @@ class MiniCPMO(MiniCPMVBaseModel):

         max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
         assert isinstance(max_patches, int)
-
         all_pixel_values = torch.nn.utils.rnn.pad_sequence(
             all_pixel_values_lst, batch_first=True, padding_value=0.0
         )
+
         B, L, _ = all_pixel_values.shape
         all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
         patch_attn_mask = torch.zeros(
@@ -1875,53 +1821,23 @@ class MiniCPMO(MiniCPMVBaseModel):
         forward_batch: ForwardBatch,
         **kwargs: Any,
     ) -> torch.Tensor:
-        inputs_embeds = None
-        # TODO(mick): optimize the logic here: clamp, merge and embedding should happens at most once
-        if (
-            not forward_batch.forward_mode.is_decode()
-            and forward_batch.contains_image_inputs()
-        ):
-            mm_inputs = forward_batch.merge_mm_inputs()
-            inputs_embeds = embed_mm_inputs(
-                mm_input=mm_inputs,
-                input_ids=input_ids,
-                input_embedding=self.get_input_embeddings(),
-                mm_data_embedding_func=self.get_image_features,
-                placeholder_token_ids=[mm_inputs.im_token_id] + mm_inputs.pad_values,
-            )

-
-
+        mm_input = forward_batch.merge_mm_inputs()
+        placeholder_token_ids = (
+            ([mm_input.im_token_id] + [item.pad_value for item in mm_input.mm_items])
+            if forward_batch.contains_mm_inputs()
+            else []
         )
-
-
-        if (
-            not forward_batch.forward_mode.is_decode()
-            and self.config.init_audio
-            and forward_batch.contains_audio_inputs()
-        ):
-            mm_input = forward_batch.merge_mm_inputs()
-            inputs_embeds = self.get_omni_embedding(
-                input_ids=input_ids,
-                multimodal_input=mm_input,
-                input_embeds=inputs_embeds,
-                forward_mode=forward_batch.forward_mode,
-                chunk_length=self.config.audio_chunk_length,
-                stream_input=False,
-            )
-
-        forward_batch.mm_inputs = None
-
-        hidden_states = self.llm.model(
-            input_ids=None,
-            positions=positions,
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
             forward_batch=forward_batch,
-            ...
+            language_model=self.llm,
+            image_data_embedding_func=self.get_image_feature,
+            audio_data_embedding_func=self.get_audio_feature,
+            placeholder_token_ids=placeholder_token_ids,
+            positions=positions,
         )
+        return hidden_states

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
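The MiniCPM-O changes above move feature extraction from a single MultimodalInputs object to a flat List[MultimodalDataItem], flattening per-request nested lists before concatenating or stacking. Below is a minimal sketch of that gathering step; flatten_nested_list is written here as a local stand-in matching how the diff appears to use the helper it imports, and FakeItem is a hypothetical substitute for MultimodalDataItem.

import torch
from typing import Any, List


def flatten_nested_list(nested: List[Any]) -> List[Any]:
    # Local stand-in: recursively flattens [[a, b], [c]] -> [a, b, c],
    # leaving non-list leaves (e.g. tensors) untouched.
    flat: List[Any] = []
    for x in nested:
        if isinstance(x, list):
            flat.extend(flatten_nested_list(x))
        else:
            flat.append(x)
    return flat


class FakeItem:
    # Hypothetical stand-in for MultimodalDataItem with just the fields used here.
    def __init__(self, pixel_values, tgt_size):
        self.pixel_values = pixel_values
        self.tgt_size = tgt_size


def gather_image_inputs(items: List[FakeItem]):
    # Mirrors the first step of the new get_image_feature: one flat list of
    # per-slice pixel tensors plus a stacked (N, 2) tensor of target sizes.
    pixel_values = flatten_nested_list([item.pixel_values for item in items])
    tgt_sizes = torch.stack(
        flatten_nested_list([item.tgt_size for item in items]), dim=0
    )
    assert len(pixel_values) == tgt_sizes.shape[0]
    return pixel_values, tgt_sizes


# Example: one item holding a single patch tensor and an 8x8 target size.
items = [FakeItem([torch.randn(64, 3)], [torch.tensor([8, 8])])]
pix, sizes = gather_image_inputs(items)  # len(pix) == 1, sizes.shape == (1, 2)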
sglang/srt/models/minicpmv.py
CHANGED
@@ -54,12 +54,12 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list

 RawImageType = Union[Image.Image, torch.Tensor]

@@ -661,7 +661,7 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
     return tuple(int(x) for x in version_str.split("."))


-class MiniCPMVBaseModel(nn.Module):
+class MiniCPMBaseModel(nn.Module):
     """
     The abstract class of MiniCPMV can only be inherited, but cannot be
     instantiated.
@@ -853,7 +853,7 @@ class MiniCPMVBaseModel(nn.Module):
         return vlm_embedding, vision_hidden_states

     def get_input_embeddings(self) -> nn.Embedding:
-        return self.llm.
+        return self.llm.get_input_embeddings()

     def forward(
         self,
@@ -862,23 +862,14 @@ class MiniCPMVBaseModel(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: Any,
     ) -> torch.Tensor:
-
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-
-
-        )
-
-        hidden_states = self.llm.model(
-            input_ids=None,
+            image_data_embedding_func=self.get_image_feature,
+            language_model=self.llm,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
-        )
-
-        return self.logits_processor(
-            input_ids, hidden_states, self.llm.lm_head, forward_batch
         )
+        return hidden_states

     def init_llm(
         self,
@@ -913,11 +904,11 @@ class MiniCPMVBaseModel(nn.Module):
     ) -> torch.Tensor:
         raise NotImplementedError

-    def
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         raise NotImplementedError


-class MiniCPMV2_6(MiniCPMVBaseModel):
+class MiniCPMV2_6(MiniCPMBaseModel):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -1023,14 +1014,13 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         )
         return vision_embedding

-    def
-        self,
-        image_inputs: MultimodalInputs,
-    ) -> torch.Tensor:
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # list of tensors
-        pixel_values =
-
-
+        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        tgt_sizes = torch.stack(
+            flatten_nested_list([item.tgt_size for item in items]), dim=0
+        )
+        assert len(pixel_values) == tgt_sizes.shape[0]

         device = self.vpm.embeddings.position_embedding.weight.device
         dtype = self.vpm.embeddings.position_embedding.weight.dtype
@@ -1040,10 +1030,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel):

         max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
         assert isinstance(max_patches, int)
-
         all_pixel_values = torch.nn.utils.rnn.pad_sequence(
             all_pixel_values_lst, batch_first=True, padding_value=0.0
         )
+
         B, L, _ = all_pixel_values.shape
         all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
         patch_attn_mask = torch.zeros(
sglang/srt/models/mllama.py
CHANGED
@@ -796,14 +796,16 @@ class MllamaForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False

-    def pad_input_ids(self, input_ids: List[int],
-        pixel_values =
-
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pixel_values = torch.cat(
+            [item.pixel_values for item in mm_inputs.mm_items], dim=0
+        )
+        pad_values = [item.pad_value for item in mm_inputs.mm_items]

         num_concurrent_media, num_tiles = pixel_values.shape[1:3]
         num_patches = self.vision_model.num_patches
         image_len = num_concurrent_media * num_tiles * num_patches
-
+        mm_inputs.num_image_tokens = image_len

         pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))

@@ -815,10 +817,16 @@ class MllamaForConditionalGeneration(nn.Module):

         # pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
         max_num_images = max_num_tiles = bs = 0
-        for i,
-
-
-
+        for i, mm_input in enumerate(forward_batch.mm_inputs):
+
+            if not forward_batch.encoder_cached[i] and mm_input is not None:
+                pixel_values = torch.cat(
+                    [item.pixel_values for item in mm_input.mm_items], dim=0
+                )
+                # max_num_images = max(max_num_images, sum(1 if item.is_image() else 0 for item in mm_input.items))
+                max_num_images = max(max_num_images, pixel_values.shape[1])
+
+                max_num_tiles = max(max_num_tiles, pixel_values.shape[2])
                 bs += 1

         if max_num_images * max_num_tiles * bs == 0:
@@ -842,17 +850,24 @@ class MllamaForConditionalGeneration(nn.Module):
         )
         i = 0
         encoder_lens_need = []
-
-
+
+        for k, mm_input in enumerate(forward_batch.mm_inputs):
+            if forward_batch.encoder_cached[k] or mm_input is None:
                 continue

             encoder_lens_need.append(forward_batch.encoder_lens[k])
-
-
+            pixel_values = torch.cat(
+                [item.pixel_values for item in mm_input.mm_items], dim=0
+            )
+            for j in range(pixel_values.shape[1]):
+                img = pixel_values[0, j]
                 num_tiles = img.shape[0]
                 batched_images[i, j, :num_tiles] = img
-                batched_ar_ids[i, j] =
-
+                batched_ar_ids[i, j] = mm_input.mm_items[0].aspect_ratio_id[0, j]
+
+                batched_ar_mask[i, j, :num_tiles] = mm_input.mm_items[
+                    0
+                ].aspect_ratio_mask[0, j]
             i += 1

         return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need
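The rewritten pad_input_ids derives pad_values per item and then tiles them to cover image_len placeholder slots. A small worked example of that arithmetic, using hypothetical token ids and assuming a final truncation to image_len (the tiling line is taken from the diff; the truncation is an assumption):

# Tile the per-item pad values until they cover image_len entries.
pad_values = [151655, 151656]   # hypothetical per-item pad token ids
image_len = 5
pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))
print(pad_ids[:image_len])      # -> [151655, 151656, 151655, 151656, 151655]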
sglang/srt/models/qwen2.py
CHANGED
@@ -261,11 +261,14 @@ class Qwen2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-    def
+    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
-            return self.
+            return self.get_input_embeddings()(input_ids) * self.config.scale_emb
         else:
-            return self.
+            return self.get_input_embeddings()(input_ids)
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.embed_tokens

     def forward(
         self,
@@ -358,10 +361,10 @@ class Qwen2ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

-    def
-        return self.model.
+    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embedding(input_ids)

-    def
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens

     @torch.no_grad()
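The new get_input_embedding on Qwen2Model multiplies token embeddings by config.scale_emb when the config defines it. A minimal, self-contained sketch of that behavior follows; TinyModel and its scale_emb field are stand-ins, not sglang's Qwen2Model or its config.

import torch
import torch.nn as nn
from typing import Optional


class TinyModel(nn.Module):
    # Sketch of the embedding-scaling behavior added in the diff: if the config
    # defines scale_emb, token embeddings are multiplied by it.
    def __init__(self, vocab_size: int, hidden_size: int, scale_emb: Optional[float] = None):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.scale_emb = scale_emb  # stands in for config.scale_emb

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embed_tokens

    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
        embeds = self.get_input_embeddings()(input_ids)
        return embeds * self.scale_emb if self.scale_emb is not None else embeds


model = TinyModel(vocab_size=16, hidden_size=8, scale_emb=12.0)
out = model.get_input_embedding(torch.tensor([1, 2, 3]))  # shape (3, 8), scaled by 12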
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -30,22 +30,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import
+from transformers import Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
-from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
-    Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
 )
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-    Qwen2_5_VLForConditionalGeneration,
-)

-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -57,7 +48,7 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
@@ -513,19 +504,24 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

-    def pad_input_ids(self, input_ids: List[int],
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int =
-        im_end_id: int =
+        im_start_id: int = mm_inputs.im_start_id
+        im_end_id: int = mm_inputs.im_end_id

         media_token_pairs = [(im_start_id, im_end_id)]
         pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

-        ...
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thws.dim() == 2, image_grid_thws.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
         return image_embeds

     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -570,18 +566,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 f"(3, seq_len) positions, but got {positions.size()}"
             )

-
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-
-
-        )
-
-        hidden_states = self.model(
-            input_ids=None,
+            language_model=self.model,
+            image_data_embedding_func=self.get_image_feature,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
         )

         if not get_embedding:
@@ -594,9 +584,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
             ("gate_up_proj", "up_proj", 1),
             ("gate_up_proj", "gate_proj", 0),
         ]