sglang 0.4.5.post3__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -9
- sglang/compile_deep_gemm.py +45 -4
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +9 -3
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +59 -11
- sglang/srt/disaggregation/mini_lb.py +45 -8
- sglang/srt/disaggregation/mooncake/conn.py +198 -31
- sglang/srt/disaggregation/prefill.py +24 -9
- sglang/srt/entrypoints/http_server.py +8 -2
- sglang/srt/function_call_parser.py +77 -5
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +28 -10
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/layernorm.py +38 -16
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/deep_gemm.py +17 -10
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +170 -126
- sglang/srt/managers/data_parallel_controller.py +10 -3
- sglang/srt/managers/io_struct.py +7 -0
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +29 -12
- sglang/srt/managers/scheduler.py +31 -20
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +87 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -3
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +11 -24
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +144 -70
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpmo.py +5 -1
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +18 -8
- sglang/srt/server_args.py +15 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +2 -1
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +36 -18
- sglang/version.py +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/METADATA +4 -5
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/RECORD +70 -68
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -14,8 +14,6 @@ _is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
-else:
-    from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -84,6 +82,12 @@ class RotaryEmbedding(CustomOp):
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
         if not _is_cuda:
             cache = cache.to(dtype)
+
+        if not _is_cuda or self.head_size not in [64, 128, 256, 512]:
+            from vllm._custom_ops import rotary_embedding
+
+            self.vllm_rotary_embedding = rotary_embedding
+
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
@@ -160,7 +164,7 @@ class RotaryEmbedding(CustomOp):
             )
         else:
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-            vllm_rotary_embedding(
+            self.vllm_rotary_embedding(
                 positions,
                 query,
                 key,
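Note: with this change the sgl_kernel rope op is used only on CUDA and only for the head sizes it supports; otherwise `rotary_embedding` is imported lazily from `vllm._custom_ops` in `__init__` and called through `self.vllm_rotary_embedding`. A minimal standalone sketch of that selection rule (the helper below is illustrative, not part of sglang):

# Illustrative only: mirrors the backend-selection logic added in this diff.
_SGL_KERNEL_HEAD_SIZES = (64, 128, 256, 512)  # sizes the fused CUDA kernel handles


def pick_rope_backend(is_cuda: bool, head_size: int) -> str:
    """Return which rotary-embedding path the layer would take."""
    if is_cuda and head_size in _SGL_KERNEL_HEAD_SIZES:
        return "sgl_kernel.apply_rope_with_cos_sin_cache_inplace"
    # Fallback mirrors `self.vllm_rotary_embedding = rotary_embedding`,
    # imported lazily from vllm._custom_ops only when needed.
    return "vllm._custom_ops.rotary_embedding"


assert pick_rope_backend(True, 128).startswith("sgl_kernel")
assert pick_rope_backend(True, 96).startswith("vllm")
assert pick_rope_backend(False, 128).startswith("vllm")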
@@ -665,6 +669,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
+        dtype = query.dtype
         query_rot = query[..., : self.rotary_dim]
         key_rot = key[..., : self.rotary_dim]
         if self.rotary_dim < self.head_size:
@@ -695,7 +700,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         else:
             query = query_rot
             key = key_rot
-        return query, key
+        return query.to(dtype), key.to(dtype)
 
 
 class Llama3RotaryEmbedding(RotaryEmbedding):
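Note: `forward_native` in `DeepseekScalingRotaryEmbedding` now records the incoming dtype and casts the rotated query/key back before returning, so bf16/fp16 callers are not silently upcast by the fp32 cos/sin cache. A tiny illustration of the pattern, with a made-up stand-in for the actual rotation:

import torch


def rotate_preserving_dtype(query: torch.Tensor, cos_sin_cache: torch.Tensor) -> torch.Tensor:
    """Toy stand-in for the change: do the math in the cache's precision,
    then cast the result back to the caller's dtype."""
    dtype = query.dtype                                   # remember bf16/fp16
    out = query.to(cos_sin_cache.dtype) * cos_sin_cache   # placeholder for the real rotation
    return out.to(dtype)                                  # mirrors `return query.to(dtype), key.to(dtype)`


q = torch.randn(2, 4, dtype=torch.bfloat16)
cache = torch.randn(4, dtype=torch.float32)
assert rotate_preserving_dtype(q, cache).dtype == torch.bfloat16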
@@ -876,142 +881,181 @@ class MRotaryEmbedding(RotaryEmbedding):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    # Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
     @staticmethod
-    def
-
-        image_grid_thw: Union[List[List[int]], torch.Tensor],
-        video_grid_thw: Union[List[List[int]], torch.Tensor],
+    def get_rope_index(
+        spatial_merge_size: int,
         image_token_id: int,
         video_token_id: int,
         vision_start_token_id: int,
-
-        spatial_merge_size: int,
-        context_len: int = 0,
-        seq_len: Optional[int] = None,
-        second_per_grid_ts: Optional[torch.Tensor] = None,
+        model_type: str,
         tokens_per_second: Optional[int] = None,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            input_tokens_tensor == vision_start_token_id
-        ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
-        llm_pos_ids_list: list = []
-
-        st = 0
-        remain_images, remain_videos = image_nums, video_nums
-
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            if image_token_id in input_tokens and remain_images > 0:
-                ed_image = input_tokens.index(image_token_id, st)
-            else:
-                ed_image = len(input_tokens) + 1
-            if video_token_id in input_tokens and remain_videos > 0:
-                ed_video = input_tokens.index(video_token_id, st)
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
-                )
-                image_index += 1
-                remain_images -= 1
-                second_per_grid_t = 0
-                ed = ed_image
-            else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
-                if second_per_grid_ts is not None:
-                    second_per_grid_t = second_per_grid_ts[video_index]
-                else:
-                    second_per_grid_t = 1.0
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
-            llm_grid_t, llm_grid_h, llm_grid_w = (
-                t,
-                h // spatial_merge_size,
-                w // spatial_merge_size,
-            )
-            text_len = ed - st
-
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-            )
-
-            t_index = (
-                torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
-                * second_per_grid_t
-                * tokens_per_second
-            ).flatten()
-
-            h_index = (
-                torch.arange(llm_grid_h)
-                .view(1, -1, 1)
-                .expand(llm_grid_t, -1, llm_grid_w)
-                .flatten()
-            )
-            w_index = (
-                torch.arange(llm_grid_w)
-                .view(1, 1, -1)
-                .expand(llm_grid_t, llm_grid_h, -1)
-                .flatten()
-            )
-            llm_pos_ids_list.append(
-                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+        input_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            position_ids = torch.ones(
+                3,
+                input_ids.shape[0],
+                input_ids.shape[1],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
             )
-
-
-
-
-
-
-
+            image_index, video_index = 0, 0
+            for i, input_ids in enumerate(total_input_ids):
+                image_nums, video_nums = 0, 0
+                vision_start_indices = torch.argwhere(
+                    input_ids == vision_start_token_id
+                ).squeeze(1)
+                vision_tokens = input_ids[vision_start_indices + 1]
+                image_nums = (vision_tokens == image_token_id).sum()
+                video_nums = (vision_tokens == video_token_id).sum()
+                input_tokens = input_ids.tolist()
+                llm_pos_ids_list: list = []
+                st = 0
+                remain_images, remain_videos = image_nums, video_nums
+                for _ in range(image_nums + video_nums):
+                    if image_token_id in input_tokens and remain_images > 0:
+                        ed_image = input_tokens.index(image_token_id, st)
+                    else:
+                        ed_image = len(input_tokens) + 1
+                    if video_token_id in input_tokens and remain_videos > 0:
+                        ed_video = input_tokens.index(video_token_id, st)
+                    else:
+                        ed_video = len(input_tokens) + 1
+                    if ed_image < ed_video:
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        second_per_grid_t = 0
+                        image_index += 1
+                        remain_images -= 1
+                        ed = ed_image
+                    else:
+                        t, h, w = (
+                            video_grid_thw[video_index][0],
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+                        if second_per_grid_ts is not None:
+                            second_per_grid_t = second_per_grid_ts[video_index]
+                        else:
+                            second_per_grid_t = 1.0
+                        video_index += 1
+                        remain_videos -= 1
+                        ed = ed_video
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t.item(),
+                        h.item() // spatial_merge_size,
+                        w.item() // spatial_merge_size,
+                    )
+                    text_len = ed - st
+
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                    if model_type == "qwen2_5_vl":
+                        range_tensor = torch.arange(llm_grid_t).view(-1, 1)
+                        expanded_range = range_tensor.expand(
+                            -1, llm_grid_h * llm_grid_w
+                        )
+
+                        time_tensor = (
+                            expanded_range * second_per_grid_t * tokens_per_second
+                        )
+
+                        time_tensor_long = time_tensor.long()
+                        t_index = time_tensor_long.flatten()
+                    elif model_type == "qwen2_vl":
+                        t_index = (
+                            torch.arange(llm_grid_t)
+                            .view(-1, 1)
+                            .expand(-1, llm_grid_h * llm_grid_w)
+                            .flatten()
+                        )
+                    else:
+                        raise RuntimeError("Unimplemented")
+                    h_index = (
+                        torch.arange(llm_grid_h)
+                        .view(1, -1, 1)
+                        .expand(llm_grid_t, -1, llm_grid_w)
+                        .flatten()
+                    )
+                    w_index = (
+                        torch.arange(llm_grid_w)
+                        .view(1, 1, -1)
+                        .expand(llm_grid_t, llm_grid_h, -1)
+                        .flatten()
+                    )
+                    llm_pos_ids_list.append(
+                        torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    )
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    text_len = len(input_tokens) - st
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+                position_ids[..., i, :] = llm_positions.to(position_ids.device)
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
+                )
+            mrope_position_deltas = torch.tensor(
+                mrope_position_deltas, device=input_ids.device
+            ).unsqueeze(1)
+            return position_ids, mrope_position_deltas
+        else:
+            s = input_ids.shape[1]
+            position_ids = torch.arange(s)
+            position_ids = (
+                position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
             )
-
-
-
-
-
-        return llm_positions.tolist(), mrope_position_delta
+            max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+                -1, keepdim=True
+            )[0]
+            mrope_position_deltas = max_position_ids + 1 - s
+            return position_ids, mrope_position_deltas
 
     @staticmethod
     def get_next_input_positions(
         mrope_position_delta: int,
         context_len: int,
         seq_len: int,
-    ) ->
-        return
-
-
-
+    ) -> torch.Tensor:
+        return torch.tensor(
+            [
+                list(
+                    range(
+                        context_len + mrope_position_delta,
+                        seq_len + mrope_position_delta,
+                    )
                 )
-
-
-
+                for _ in range(3)
+            ]
+        )
 
 
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
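Note: `get_rope_index` now mirrors the Hugging Face Qwen2-VL reference: it takes batched `input_ids` plus image/video grid shapes and returns 3xBxL position ids together with per-sequence mrope deltas, with `model_type` ("qwen2_vl" vs "qwen2_5_vl") selecting how the temporal index is scaled. The following self-contained snippet only illustrates the 3-D (t, h, w) position layout the method builds for one text-then-image sequence; the grid sizes are invented and nothing from sglang is imported:

import torch

# Tiny illustration of the (t, h, w) position layout for one vision block.
llm_grid_t, llm_grid_h, llm_grid_w = 1, 2, 2   # grid after spatial merge (made up)
text_len, st_idx = 3, 0                        # 3 text tokens precede the image

text_pos = torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
vision_pos = torch.stack([t_index, h_index, w_index]) + text_len + st_idx

positions = torch.cat([text_pos, vision_pos], dim=1)
print(positions)
# tensor([[0, 1, 2, 3, 3, 3, 3],   # temporal
#         [0, 1, 2, 3, 3, 4, 4],   # height
#         [0, 1, 2, 3, 4, 3, 4]])  # width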
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -23,11 +23,13 @@ import psutil
 import setproctitle
 import zmq
 
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -226,9 +228,14 @@ class DataParallelController:
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
         self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
 
-    def round_robin_scheduler(self, req):
-        self.
-
+    def round_robin_scheduler(self, req: Req):
+        if self.server_args.disaggregation_mode == "null":
+            self.workers[self.round_robin_counter].send_pyobj(req)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                self.workers
+            )
+        else:
+            self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
 
     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()
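Note: `round_robin_scheduler` now routes by `bootstrap_room` whenever prefill/decode disaggregation is enabled, so the prefill and decode halves of a request land on the same data-parallel worker; plain round-robin is kept for normal serving. A toy sketch of that dispatch rule (the request type is a stand-in, not the real class):

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeReq:  # stand-in for the real scheduler request object
    bootstrap_room: Optional[int] = None


def pick_worker(
    req: FakeReq, num_workers: int, rr_counter: int, disaggregation_mode: str
) -> tuple[int, int]:
    """Return (worker_index, next_round_robin_counter)."""
    if disaggregation_mode == "null":
        # normal serving: rotate through the workers
        return rr_counter, (rr_counter + 1) % num_workers
    # disaggregated serving: hash the bootstrap room so the prefill and decode
    # instances of the same request agree on the worker
    return req.bootstrap_room % num_workers, rr_counter


assert pick_worker(FakeReq(bootstrap_room=7), 4, 0, "prefill")[0] == 3
assert pick_worker(FakeReq(), 4, 2, "null") == (2, 3)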
sglang/srt/managers/io_struct.py
CHANGED
@@ -97,6 +97,7 @@ class GenerateReqInput:
 
     # For disaggregated inference
     bootstrap_host: Optional[Union[List[str], str]] = None
+    bootstrap_port: Optional[Union[List[int], int]] = None
     bootstrap_room: Optional[Union[List[int], int]] = None
 
     def normalize_batch_and_arguments(self):
@@ -400,6 +401,9 @@ class GenerateReqInput:
             bootstrap_host=(
                 self.bootstrap_host[i] if self.bootstrap_host is not None else None
             ),
+            bootstrap_port=(
+                self.bootstrap_port[i] if self.bootstrap_port is not None else None
+            ),
             bootstrap_room=(
                 self.bootstrap_room[i] if self.bootstrap_room is not None else None
             ),
@@ -447,6 +451,7 @@ class TokenizedGenerateReqInput:
 
     # For disaggregated inference
     bootstrap_host: Optional[str] = None
+    bootstrap_port: Optional[int] = None
    bootstrap_room: Optional[int] = None
 
 
@@ -463,6 +468,8 @@ class EmbeddingReqInput:
     image_data: Optional[
         Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
     ] = None
+    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
+    audio_data: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The request id.
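Note: `bootstrap_port` is added next to `bootstrap_host` and `bootstrap_room` on both `GenerateReqInput` and `TokenizedGenerateReqInput`, and is split per item when a batched request is normalized. A hypothetical batched payload showing how the three fields line up (hosts, ports, and room ids below are invented):

# Hypothetical request body illustrating the new bootstrap_port field.
batched_generate_req = {
    "text": ["What is KV-cache transfer?", "Summarize disaggregation."],
    # one entry per request in the batch, mirroring bootstrap_host/bootstrap_room
    "bootstrap_host": ["10.0.0.1", "10.0.0.2"],
    "bootstrap_port": [8998, 8998],
    "bootstrap_room": [12345, 67890],
}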
sglang/srt/managers/mm_utils.py
CHANGED
@@ -10,12 +10,13 @@ import torch
 from torch import nn
 
 from sglang.srt.managers.schedule_batch import (
+    Modality,
     MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import print_warning_once
+from sglang.srt.utils import flatten_nested_list, print_warning_once
 
 logger = logging.getLogger(__name__)
 
@@ -97,31 +98,80 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         return padded_ids
 
 
-class
+class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
     """In this pattern, data tokens should be represented as repetitions of a single token
     e.g. <image><image>....<image>, or <audio><audio>...<audio>
     """
 
-    def __init__(self,
-        self.
+    def __init__(self, token_ids: List[int]) -> None:
+        self.token_ids = token_ids
 
-    def pad_input_tokens(
+    def pad_input_tokens(
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
+    ) -> List[int]:
         """
-
+        Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
+        and replaces each region with the corresponding `pad_value` from `mm_inputs.mm_items`.
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
-
+        if not pad_values:
+            # No multimodal items, return original input_ids
+            return input_ids
+        if not input_ids:
+            return []
 
         input_ids_tensor = torch.tensor(input_ids)
-
+        device = input_ids_tensor.device
+        token_ids_tensor = torch.tensor(self.token_ids, device=device)
+        mask = torch.isin(input_ids_tensor, token_ids_tensor)
 
-
-
-
-
+        if not mask.any():
+            # No tokens match token_ids, return original input_ids
+            return input_ids
+
+        # Find contiguous regions
+        padded_mask = torch.cat(
+            (
+                torch.tensor([False], device=device),
+                mask,
+                torch.tensor([False], device=device),
+            )
+        )
+        # Find indices where the mask value changes
+        diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
+
+        # Start indices are where False changes to True
+        starts = diff_indices[::2]
+        # End indices are where True changes to False (exclusive index)
+        ends = diff_indices[1::2]
+
+        # Check if the number of regions matches the number of pad values
+        if len(starts) != len(pad_values):
+            # Maybe log a warning here?
+            num_regions = len(starts)
+            num_pad_values = len(pad_values)
+            if num_regions > 0 and num_pad_values > 0:
+                pad_values = (pad_values * (num_regions // num_pad_values + 1))[
+                    :num_regions
+                ]
+            else:  # If no regions or no pad_values, this loop won't run anyway.
+                pad_values = []  # Ensure pad_values is empty if starts is empty
+
+        # Create a copy to modify
+        output_ids_tensor = input_ids_tensor.clone()
+
+        # Replace tokens in each region with the corresponding pad value
+        # Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
+        for i in range(min(len(starts), len(pad_values))):
+            start_idx = starts[i]
+            end_idx = ends[i]
+            pad_value = pad_values[i]
+            if pad_value is not None:  # Ensure pad_value is not None before assignment
+                output_ids_tensor[start_idx:end_idx] = pad_value
+            else:
+                logger.warning(f"Skipping region {i} due to None pad_value.")
 
-
-        return input_ids_tensor.tolist()
+        return output_ids_tensor.tolist()
 
 
 def get_embedding_and_mask(
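Note: the new `MultiModalityDataPaddingPatternMultimodalTokens.pad_input_tokens` locates runs of multimodal placeholder tokens with a mask-diff trick: pad the boolean membership mask with `False` on both ends, and region boundaries are exactly the indices where consecutive mask values differ. A small standalone demonstration of the same trick (token ids and pad values are arbitrary):

import torch

# Toy input: 101/102 are placeholder tokens, everything else is text.
input_ids = torch.tensor([5, 101, 101, 101, 7, 8, 102, 102, 9])
mm_token_ids = torch.tensor([101, 102])
pad_values = [90001, 90002]  # one pad value per multimodal item (made-up ids)

mask = torch.isin(input_ids, mm_token_ids)
padded_mask = torch.cat((torch.tensor([False]), mask, torch.tensor([False])))
diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
starts, ends = diff_indices[::2], diff_indices[1::2]  # [1, 6] and [4, 8]

out = input_ids.clone()
for s, e, v in zip(starts, ends, pad_values):
    out[s:e] = v

print(out.tolist())  # [5, 90001, 90001, 90001, 7, 8, 90002, 90002, 9]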
@@ -150,7 +200,6 @@ def get_embedding_and_mask(
     ).unsqueeze(-1)
 
     num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
-
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
             f"Number of tokens in multimodal embedding does not match those in the input text."
@@ -190,13 +239,13 @@ def embed_mm_inputs(
     audio_data_embedding_func: Callable[
         [List[MultimodalDataItem]], torch.Tensor
     ] = None,
-
+    placeholder_tokens: dict[Modality, List[int]] = None,
 ) -> Optional[torch.Tensor]:
     """
     Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations
 
     Args:
-
+        placeholder_tokens: denoting the token of multimodal data in input_ids.
             If none, the pad_values of multimodal items are used
 
     Returns:
@@ -208,9 +257,17 @@ def embed_mm_inputs(
 
     # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
     # we assume that multimodal data are represented with its pad_values in input_ids
-
-
-
+    # See `pad_input_ids` for more detail
+
+    # if placeholder_tokens is specified
+    if placeholder_tokens is not None:
+        placeholder_token_ids = flatten_nested_list(
+            [placeholder_token for placeholder_token in placeholder_tokens.values()]
+        )
+    else:
+        placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items]
+
+    assert isinstance(placeholder_token_ids[0], int)
 
     placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
 
@@ -233,7 +290,7 @@ def embed_mm_inputs(
     using_all_items = False
     if len(appearing_items) == 0:
         # This happens mostly when arg placeholder_token_ids is passed
-        logger.
+        logger.warning(
            "No multimodal data item's pad value exist in placeholder ids. Using all items"
        )
        using_all_items = True
@@ -253,7 +310,8 @@ def embed_mm_inputs(
             data_embedding_func=image_data_embedding_func,
             embedding_items=items,
             placeholder_tensor=(
-
+                # use the specified modality token to identify the location to embed
+                placeholder_tokens[Modality.IMAGE]
                 if using_all_items
                 else torch.tensor(
                     [item.pad_value for item in items],
@@ -275,7 +333,7 @@ def embed_mm_inputs(
             data_embedding_func=audio_data_embedding_func,
             embedding_items=items,
             placeholder_tensor=(
-
+                placeholder_tokens[Modality.AUDIO]
                 if using_all_items
                 else torch.tensor(
                     [item.pad_value for item in items],
@@ -296,7 +354,7 @@ def embed_mm_inputs(
         input_ids.clamp_(min=0, max=vocab_size - 1)
         inputs_embeds = input_embedding(input_ids)
 
-    # 4.
+    # 4. Scatter embeddings into input embedding
     for embedding, mask in zip(embeddings, masks):
         mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
         inputs_embeds = inputs_embeds.masked_scatter(
@@ -316,7 +374,7 @@ def general_mm_embed_routine(
     audio_data_embedding_func: Callable[
         [List[MultimodalDataItem]], torch.Tensor
     ] = None,
-
+    placeholder_tokens: dict[Modality, List[int]] = None,
     **kwargs,
 ) -> torch.Tensor:
     """
@@ -328,7 +386,6 @@ def general_mm_embed_routine(
         audio_data_embedding_func : the function returning the image embedding
 
     Returns:
-        inputs_embedding
         forwarded hidden states
 
     """
@@ -346,9 +403,9 @@ def general_mm_embed_routine(
             input_embedding=embed_tokens,
             image_data_embedding_func=image_data_embedding_func,
             audio_data_embedding_func=audio_data_embedding_func,
-
+            placeholder_tokens=placeholder_tokens,
         )
-        # once used, mm_inputs is useless
+        # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
         # just being defensive here
         forward_batch.mm_inputs = None
     else:
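Note: `embed_mm_inputs` and `general_mm_embed_routine` now accept an optional `placeholder_tokens` mapping from modality to the token ids used as placeholders in `input_ids`, falling back to the items' pad values when it is omitted. A minimal sketch of that resolution step, with a stand-in `Modality` enum and invented token ids:

from enum import Enum, auto
from typing import Dict, List, Optional


class Modality(Enum):  # stand-in for sglang.srt.managers.schedule_batch.Modality
    IMAGE = auto()
    AUDIO = auto()


def resolve_placeholder_token_ids(
    placeholder_tokens: Optional[Dict[Modality, List[int]]],
    item_pad_values: List[int],
) -> List[int]:
    if placeholder_tokens is not None:
        # flatten {IMAGE: [...], AUDIO: [...]} into a single id list
        return [tok for toks in placeholder_tokens.values() for tok in toks]
    # otherwise fall back to the per-item pad values, as before this change
    return item_pad_values


assert resolve_placeholder_token_ids({Modality.IMAGE: [151655]}, [42]) == [151655]
assert resolve_placeholder_token_ids(None, [42, 43]) == [42, 43]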
sglang/srt/managers/multimodal_processors/base_processor.py
CHANGED
@@ -8,6 +8,7 @@ from typing import List, Optional
 
 import numpy as np
 import PIL
+from PIL import Image
 from transformers import BaseImageProcessorFast
 
 from sglang.srt.managers.schedule_batch import Modality
@@ -92,7 +93,12 @@ class BaseMultimodalProcessor(ABC):
 
     @abstractmethod
     async def process_mm_data_async(
-        self,
+        self,
+        image_data,
+        input_text,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
     ):
         pass
 
@@ -104,6 +110,8 @@ class BaseMultimodalProcessor(ABC):
         from decord import VideoReader, cpu
 
         # Before processing inputs
+        if not image_data or len(image_data) == 0:
+            return []
         estimated_frames_list = []
         for image in image_data:
             if isinstance(image, str) and image.startswith("video:"):
@@ -215,6 +223,9 @@ class BaseMultimodalProcessor(ABC):
             discard_alpha_channel: if True, discards the alpha channel in the returned images
 
         """
+
+        if image_data is None:
+            image_data = []
         if isinstance(multimodal_tokens.image_token, int):
             multimodal_tokens.image_token = (
                 self._processor.tokenizer.convert_ids_to_tokens(
@@ -229,6 +240,8 @@ class BaseMultimodalProcessor(ABC):
             prompt = self._processor.tokenizer.decode(prompt)
         else:
             prompt = prompt
+
+        assert isinstance(prompt, str)
         if return_text:
             import re
 
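Note: the base processor hunks above make image-less requests safe: `None` image data is normalized to an empty list and video-frame estimation returns early when there is nothing to process. A trivial sketch of the guard pattern (the helper name is illustrative):

from typing import List, Optional, Union


def normalize_image_data(image_data: Optional[Union[str, List[str]]]) -> List[str]:
    """Mirror of the new guards: treat None/empty input as 'no multimodal data'."""
    if image_data is None:
        image_data = []
    if not image_data or len(image_data) == 0:
        return []
    return image_data if isinstance(image_data, list) else [image_data]


assert normalize_image_data(None) == []
assert normalize_image_data([]) == []
assert normalize_image_data("cat.png") == ["cat.png"]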