sglang 0.4.5.post3__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -9
  3. sglang/compile_deep_gemm.py +45 -4
  4. sglang/srt/code_completion_parser.py +1 -1
  5. sglang/srt/configs/deepseekvl2.py +1 -1
  6. sglang/srt/configs/model_config.py +9 -3
  7. sglang/srt/constrained/llguidance_backend.py +78 -61
  8. sglang/srt/conversation.py +34 -1
  9. sglang/srt/disaggregation/decode.py +59 -11
  10. sglang/srt/disaggregation/mini_lb.py +45 -8
  11. sglang/srt/disaggregation/mooncake/conn.py +198 -31
  12. sglang/srt/disaggregation/prefill.py +24 -9
  13. sglang/srt/entrypoints/http_server.py +8 -2
  14. sglang/srt/function_call_parser.py +77 -5
  15. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  16. sglang/srt/layers/attention/flashattention_backend.py +28 -10
  17. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  18. sglang/srt/layers/attention/vision.py +2 -0
  19. sglang/srt/layers/layernorm.py +38 -16
  20. sglang/srt/layers/logits_processor.py +2 -2
  21. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
  25. sglang/srt/layers/pooler.py +6 -0
  26. sglang/srt/layers/quantization/awq.py +5 -1
  27. sglang/srt/layers/quantization/deep_gemm.py +17 -10
  28. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  29. sglang/srt/layers/radix_attention.py +13 -3
  30. sglang/srt/layers/rotary_embedding.py +170 -126
  31. sglang/srt/managers/data_parallel_controller.py +10 -3
  32. sglang/srt/managers/io_struct.py +7 -0
  33. sglang/srt/managers/mm_utils.py +85 -28
  34. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  35. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  36. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  37. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  38. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  40. sglang/srt/managers/schedule_batch.py +29 -12
  41. sglang/srt/managers/scheduler.py +31 -20
  42. sglang/srt/managers/tokenizer_manager.py +5 -1
  43. sglang/srt/mem_cache/memory_pool.py +87 -0
  44. sglang/srt/model_executor/cuda_graph_runner.py +4 -3
  45. sglang/srt/model_executor/forward_batch_info.py +51 -95
  46. sglang/srt/model_executor/model_runner.py +11 -24
  47. sglang/srt/models/deepseek.py +12 -2
  48. sglang/srt/models/deepseek_nextn.py +101 -6
  49. sglang/srt/models/deepseek_v2.py +144 -70
  50. sglang/srt/models/deepseek_vl2.py +9 -4
  51. sglang/srt/models/gemma3_causal.py +1 -1
  52. sglang/srt/models/llama4.py +0 -1
  53. sglang/srt/models/minicpmo.py +5 -1
  54. sglang/srt/models/mllama4.py +2 -2
  55. sglang/srt/models/qwen2_5_vl.py +3 -6
  56. sglang/srt/models/qwen2_vl.py +3 -7
  57. sglang/srt/models/roberta.py +178 -0
  58. sglang/srt/openai_api/adapter.py +18 -8
  59. sglang/srt/server_args.py +15 -22
  60. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  61. sglang/srt/torch_memory_saver_adapter.py +10 -1
  62. sglang/srt/utils.py +2 -1
  63. sglang/test/runners.py +6 -13
  64. sglang/test/test_utils.py +36 -18
  65. sglang/version.py +1 -1
  66. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/METADATA +4 -5
  67. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/RECORD +70 -68
  68. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
  69. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
  70. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py

@@ -14,8 +14,6 @@ _is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
-else:
-    from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -84,6 +82,12 @@ class RotaryEmbedding(CustomOp):
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
         if not _is_cuda:
             cache = cache.to(dtype)
+
+        if not _is_cuda or self.head_size not in [64, 128, 256, 512]:
+            from vllm._custom_ops import rotary_embedding
+
+            self.vllm_rotary_embedding = rotary_embedding
+
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
@@ -160,7 +164,7 @@
             )
         else:
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-            vllm_rotary_embedding(
+            self.vllm_rotary_embedding(
                 positions,
                 query,
                 key,
@@ -665,6 +669,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
+        dtype = query.dtype
         query_rot = query[..., : self.rotary_dim]
         key_rot = key[..., : self.rotary_dim]
         if self.rotary_dim < self.head_size:
@@ -695,7 +700,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         else:
             query = query_rot
             key = key_rot
-        return query, key
+        return query.to(dtype), key.to(dtype)
 
 
 class Llama3RotaryEmbedding(RotaryEmbedding):
@@ -876,142 +881,181 @@ class MRotaryEmbedding(RotaryEmbedding):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    # Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
     @staticmethod
-    def get_input_positions(
-        input_tokens: List[int],
-        image_grid_thw: Union[List[List[int]], torch.Tensor],
-        video_grid_thw: Union[List[List[int]], torch.Tensor],
+    def get_rope_index(
+        spatial_merge_size: int,
         image_token_id: int,
         video_token_id: int,
         vision_start_token_id: int,
-        vision_end_token_id: int,
-        spatial_merge_size: int,
-        context_len: int = 0,
-        seq_len: Optional[int] = None,
-        second_per_grid_ts: Optional[torch.Tensor] = None,
+        model_type: str,
         tokens_per_second: Optional[int] = None,
-    ) -> Tuple[List[List[int]], int]:
-        """
-        Get mrope input positions and delta value.
-
-        :arg
-            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
-                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
-
-        """
-
-        if isinstance(image_grid_thw, torch.Tensor):
-            image_grid_thw = image_grid_thw.tolist()
-        if isinstance(video_grid_thw, torch.Tensor):
-            video_grid_thw = video_grid_thw.tolist()
-
-        input_tokens_tensor = torch.tensor(input_tokens)
-        vision_start_indices = torch.argwhere(
-            input_tokens_tensor == vision_start_token_id
-        ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
-        llm_pos_ids_list: list = []
-
-        st = 0
-        remain_images, remain_videos = image_nums, video_nums
-
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            if image_token_id in input_tokens and remain_images > 0:
-                ed_image = input_tokens.index(image_token_id, st)
-            else:
-                ed_image = len(input_tokens) + 1
-            if video_token_id in input_tokens and remain_videos > 0:
-                ed_video = input_tokens.index(video_token_id, st)
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
-                )
-                image_index += 1
-                remain_images -= 1
-                second_per_grid_t = 0
-                ed = ed_image
-            else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
-                if second_per_grid_ts is not None:
-                    second_per_grid_t = second_per_grid_ts[video_index]
-                else:
-                    second_per_grid_t = 1.0
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
-            llm_grid_t, llm_grid_h, llm_grid_w = (
-                t,
-                h // spatial_merge_size,
-                w // spatial_merge_size,
-            )
-            text_len = ed - st
-
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-            )
-
-            t_index = (
-                torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
-                * second_per_grid_t
-                * tokens_per_second
-            ).flatten()
-
-            h_index = (
-                torch.arange(llm_grid_h)
-                .view(1, -1, 1)
-                .expand(llm_grid_t, -1, llm_grid_w)
-                .flatten()
-            )
-            w_index = (
-                torch.arange(llm_grid_w)
-                .view(1, 1, -1)
-                .expand(llm_grid_t, llm_grid_h, -1)
-                .flatten()
-            )
-            llm_pos_ids_list.append(
-                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+        input_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            position_ids = torch.ones(
+                3,
+                input_ids.shape[0],
+                input_ids.shape[1],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
             )
-            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-
-        if st < len(input_tokens):
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            text_len = len(input_tokens) - st
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            image_index, video_index = 0, 0
+            for i, input_ids in enumerate(total_input_ids):
+                image_nums, video_nums = 0, 0
+                vision_start_indices = torch.argwhere(
+                    input_ids == vision_start_token_id
+                ).squeeze(1)
+                vision_tokens = input_ids[vision_start_indices + 1]
+                image_nums = (vision_tokens == image_token_id).sum()
+                video_nums = (vision_tokens == video_token_id).sum()
+                input_tokens = input_ids.tolist()
+                llm_pos_ids_list: list = []
+                st = 0
+                remain_images, remain_videos = image_nums, video_nums
+                for _ in range(image_nums + video_nums):
+                    if image_token_id in input_tokens and remain_images > 0:
+                        ed_image = input_tokens.index(image_token_id, st)
+                    else:
+                        ed_image = len(input_tokens) + 1
+                    if video_token_id in input_tokens and remain_videos > 0:
+                        ed_video = input_tokens.index(video_token_id, st)
+                    else:
+                        ed_video = len(input_tokens) + 1
+                    if ed_image < ed_video:
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        second_per_grid_t = 0
+                        image_index += 1
+                        remain_images -= 1
+                        ed = ed_image
+                    else:
+                        t, h, w = (
+                            video_grid_thw[video_index][0],
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+                        if second_per_grid_ts is not None:
+                            second_per_grid_t = second_per_grid_ts[video_index]
+                        else:
+                            second_per_grid_t = 1.0
+                        video_index += 1
+                        remain_videos -= 1
+                        ed = ed_video
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t.item(),
+                        h.item() // spatial_merge_size,
+                        w.item() // spatial_merge_size,
+                    )
+                    text_len = ed - st
+
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                    if model_type == "qwen2_5_vl":
+                        range_tensor = torch.arange(llm_grid_t).view(-1, 1)
+                        expanded_range = range_tensor.expand(
+                            -1, llm_grid_h * llm_grid_w
+                        )
+
+                        time_tensor = (
+                            expanded_range * second_per_grid_t * tokens_per_second
+                        )
+
+                        time_tensor_long = time_tensor.long()
+                        t_index = time_tensor_long.flatten()
+                    elif model_type == "qwen2_vl":
+                        t_index = (
+                            torch.arange(llm_grid_t)
+                            .view(-1, 1)
+                            .expand(-1, llm_grid_h * llm_grid_w)
+                            .flatten()
+                        )
+                    else:
+                        raise RuntimeError("Unimplemented")
+                    h_index = (
+                        torch.arange(llm_grid_h)
+                        .view(1, -1, 1)
+                        .expand(llm_grid_t, -1, llm_grid_w)
+                        .flatten()
+                    )
+                    w_index = (
+                        torch.arange(llm_grid_w)
+                        .view(1, 1, -1)
+                        .expand(llm_grid_t, llm_grid_h, -1)
+                        .flatten()
+                    )
+                    llm_pos_ids_list.append(
+                        torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    )
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    text_len = len(input_tokens) - st
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+                position_ids[..., i, :] = llm_positions.to(position_ids.device)
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
+                )
+            mrope_position_deltas = torch.tensor(
+                mrope_position_deltas, device=input_ids.device
+            ).unsqueeze(1)
+            return position_ids, mrope_position_deltas
+        else:
+            s = input_ids.shape[1]
+            position_ids = torch.arange(s)
+            position_ids = (
+                position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
             )
-
-        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
-        llm_positions = llm_positions[:, context_len:seq_len]
-
-        return llm_positions.tolist(), mrope_position_delta
+            max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+                -1, keepdim=True
+            )[0]
+            mrope_position_deltas = max_position_ids + 1 - s
+            return position_ids, mrope_position_deltas
 
     @staticmethod
     def get_next_input_positions(
         mrope_position_delta: int,
         context_len: int,
         seq_len: int,
-    ) -> List[List[int]]:
-        return [
-            list(
-                range(
-                    context_len + mrope_position_delta, seq_len + mrope_position_delta
+    ) -> torch.Tensor:
+        return torch.tensor(
+            [
+                list(
+                    range(
+                        context_len + mrope_position_delta,
+                        seq_len + mrope_position_delta,
+                    )
                 )
-            )
-            for _ in range(3)
-        ]
+                for _ in range(3)
+            ]
+        )
 
 
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
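
Note (not part of the diff): 0.4.6 replaces the list-based MRotaryEmbedding.get_input_positions with a batched, tensor-returning get_rope_index ported from the Hugging Face Qwen2-VL code. A minimal usage sketch for the text-only branch added above, assuming sglang 0.4.6 and torch are installed; the special-token ids below are illustrative values, not taken from this diff.

import torch

from sglang.srt.layers.rotary_embedding import MRotaryEmbedding

# A text-only batch: with no image_grid_thw / video_grid_thw, get_rope_index
# falls through to the plain-arange branch added in this release.
input_ids = torch.randint(0, 1000, (1, 16))  # (batch, seq_len)

position_ids, mrope_deltas = MRotaryEmbedding.get_rope_index(
    spatial_merge_size=2,
    image_token_id=151655,        # illustrative Qwen2-VL-style ids
    video_token_id=151656,
    vision_start_token_id=151652,
    model_type="qwen2_vl",
    tokens_per_second=2,
    input_ids=input_ids,
)

# Positions are broadcast over the three MRoPE axes: (3, batch, seq_len);
# the per-sequence delta is 0 when there is no vision content.
print(position_ids.shape)  # torch.Size([3, 1, 16])
print(mrope_deltas)        # tensor([[0]])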
sglang/srt/managers/data_parallel_controller.py

@@ -23,11 +23,13 @@ import psutil
 import setproctitle
 import zmq
 
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -226,9 +228,14 @@ class DataParallelController:
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
         self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
 
-    def round_robin_scheduler(self, req):
-        self.workers[self.round_robin_counter].send_pyobj(req)
-        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
+    def round_robin_scheduler(self, req: Req):
+        if self.server_args.disaggregation_mode == "null":
+            self.workers[self.round_robin_counter].send_pyobj(req)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                self.workers
+            )
+        else:
+            self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
 
     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()
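
Note (not part of the diff): with prefill/decode disaggregation enabled, the data-parallel controller now routes by bootstrap_room instead of round robin, so the prefill-side and decode-side requests of the same transfer land on the same DP worker. An illustrative sketch of that routing rule, using hypothetical names (FakeReq, pick_worker), not the actual DataParallelController API:

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeReq:
    bootstrap_room: Optional[int] = None


def pick_worker(req, num_workers, rr_counter, disaggregation_mode="null"):
    # "null" means disaggregation is off: plain round robin.
    if disaggregation_mode == "null":
        return rr_counter % num_workers, (rr_counter + 1) % num_workers
    # Otherwise the bootstrap_room deterministically selects a worker.
    return req.bootstrap_room % num_workers, rr_counter


# Same bootstrap_room -> same worker index, regardless of arrival order.
idx, _ = pick_worker(
    FakeReq(bootstrap_room=42), num_workers=4, rr_counter=0,
    disaggregation_mode="decode",
)
print(idx)  # 2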
sglang/srt/managers/io_struct.py

@@ -97,6 +97,7 @@ class GenerateReqInput:
 
     # For disaggregated inference
     bootstrap_host: Optional[Union[List[str], str]] = None
+    bootstrap_port: Optional[Union[List[int], int]] = None
     bootstrap_room: Optional[Union[List[int], int]] = None
 
     def normalize_batch_and_arguments(self):
@@ -400,6 +401,9 @@ GenerateReqInput:
             bootstrap_host=(
                 self.bootstrap_host[i] if self.bootstrap_host is not None else None
             ),
+            bootstrap_port=(
+                self.bootstrap_port[i] if self.bootstrap_port is not None else None
+            ),
             bootstrap_room=(
                 self.bootstrap_room[i] if self.bootstrap_room is not None else None
             ),
@@ -447,6 +451,7 @@ class TokenizedGenerateReqInput:
 
     # For disaggregated inference
     bootstrap_host: Optional[str] = None
+    bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None
 
 
@@ -463,6 +468,8 @@ class EmbeddingReqInput:
     image_data: Optional[
         Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
     ] = None
+    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
+    audio_data: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The request id.
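
Note (not part of the diff): the new bootstrap_port field follows the same fan-out rule as bootstrap_host and bootstrap_room when a batched GenerateReqInput is split into per-item requests: element i of each list goes to sub-request i, and an unset field stays None. A self-contained sketch of that pattern, with a hypothetical helper name:

def split_bootstrap_fields(bootstrap_host, bootstrap_port, bootstrap_room, i):
    # Mirrors the per-index selection shown in the hunk above.
    return {
        "bootstrap_host": bootstrap_host[i] if bootstrap_host is not None else None,
        "bootstrap_port": bootstrap_port[i] if bootstrap_port is not None else None,
        "bootstrap_room": bootstrap_room[i] if bootstrap_room is not None else None,
    }


print(split_bootstrap_fields(["10.0.0.1", "10.0.0.2"], None, [7, 8], i=1))
# {'bootstrap_host': '10.0.0.2', 'bootstrap_port': None, 'bootstrap_room': 8}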
sglang/srt/managers/mm_utils.py

@@ -10,12 +10,13 @@ import torch
 from torch import nn
 
 from sglang.srt.managers.schedule_batch import (
+    Modality,
     MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import print_warning_once
+from sglang.srt.utils import flatten_nested_list, print_warning_once
 
 logger = logging.getLogger(__name__)
 
@@ -97,31 +98,80 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         return padded_ids
 
 
-class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
+class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
     """In this pattern, data tokens should be represented as repetitions of a single token
     e.g. <image><image>....<image>, or <audio><audio>...<audio>
     """
 
-    def __init__(self, image_token_id: torch.Tensor) -> None:
-        self.image_token_id = image_token_id
+    def __init__(self, token_ids: List[int]) -> None:
+        self.token_ids = token_ids
 
-    def pad_input_tokens(self, input_ids: List[int], mm_inputs) -> List[int]:
+    def pad_input_tokens(
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
+    ) -> List[int]:
         """
-        This function will replace the data-tokens in between with pad_values accordingly
+        Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
+        and replaces each region with the corresponding `pad_value` from `mm_inputs.mm_items`.
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
-        assert len(pad_values) != 0
+        if not pad_values:
+            # No multimodal items, return original input_ids
+            return input_ids
+        if not input_ids:
+            return []
 
         input_ids_tensor = torch.tensor(input_ids)
-        mask = torch.isin(input_ids_tensor, self.image_token_id)
+        device = input_ids_tensor.device
+        token_ids_tensor = torch.tensor(self.token_ids, device=device)
+        mask = torch.isin(input_ids_tensor, token_ids_tensor)
 
-        num_image_tokens = mask.sum().item()
-        repeated_pad_values = torch.tensor(pad_values).repeat(
-            num_image_tokens // len(pad_values) + 1
-        )[:num_image_tokens]
+        if not mask.any():
+            # No tokens match token_ids, return original input_ids
+            return input_ids
+
+        # Find contiguous regions
+        padded_mask = torch.cat(
+            (
+                torch.tensor([False], device=device),
+                mask,
+                torch.tensor([False], device=device),
+            )
+        )
+        # Find indices where the mask value changes
+        diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
+
+        # Start indices are where False changes to True
+        starts = diff_indices[::2]
+        # End indices are where True changes to False (exclusive index)
+        ends = diff_indices[1::2]
+
+        # Check if the number of regions matches the number of pad values
+        if len(starts) != len(pad_values):
+            # Maybe log a warning here?
+            num_regions = len(starts)
+            num_pad_values = len(pad_values)
+            if num_regions > 0 and num_pad_values > 0:
+                pad_values = (pad_values * (num_regions // num_pad_values + 1))[
+                    :num_regions
+                ]
+            else:  # If no regions or no pad_values, this loop won't run anyway.
+                pad_values = []  # Ensure pad_values is empty if starts is empty
+
+        # Create a copy to modify
+        output_ids_tensor = input_ids_tensor.clone()
+
+        # Replace tokens in each region with the corresponding pad value
+        # Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
+        for i in range(min(len(starts), len(pad_values))):
+            start_idx = starts[i]
+            end_idx = ends[i]
+            pad_value = pad_values[i]
+            if pad_value is not None:  # Ensure pad_value is not None before assignment
+                output_ids_tensor[start_idx:end_idx] = pad_value
+            else:
+                logger.warning(f"Skipping region {i} due to None pad_value.")
 
-        input_ids_tensor[mask] = repeated_pad_values
-        return input_ids_tensor.tolist()
+        return output_ids_tensor.tolist()
 
 
 def get_embedding_and_mask(
@@ -150,7 +200,6 @@ def get_embedding_and_mask(
     ).unsqueeze(-1)
 
     num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
-
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
             f"Number of tokens in multimodal embedding does not match those in the input text."
@@ -190,13 +239,13 @@ def embed_mm_inputs(
     audio_data_embedding_func: Callable[
         [List[MultimodalDataItem]], torch.Tensor
     ] = None,
-    placeholder_token_ids: List[int] = None,
+    placeholder_tokens: dict[Modality, List[int]] = None,
 ) -> Optional[torch.Tensor]:
     """
     Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations
 
     Args:
-        placeholder_token_ids: denoting the token of multimodal data in input_ids.
+        placeholder_tokens: denoting the token of multimodal data in input_ids.
            If none, the pad_values of multimodal items are used
 
     Returns:
@@ -208,9 +257,17 @@ def embed_mm_inputs(
 
     # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
     # we assume that multimodal data are represented with its pad_values in input_ids
-    placeholder_token_ids = placeholder_token_ids or [
-        item.pad_value for item in mm_inputs.mm_items
-    ]
+    # See `pad_input_ids` for more detail
+
+    # if placeholder_tokens is specified
+    if placeholder_tokens is not None:
+        placeholder_token_ids = flatten_nested_list(
+            [placeholder_token for placeholder_token in placeholder_tokens.values()]
+        )
+    else:
+        placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items]
+
+    assert isinstance(placeholder_token_ids[0], int)
 
     placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
 
@@ -233,7 +290,7 @@ def embed_mm_inputs(
     using_all_items = False
     if len(appearing_items) == 0:
         # This happens mostly when arg placeholder_token_ids is passed
-        logger.warning_once(
+        logger.warning(
            "No multimodal data item's pad value exist in placeholder ids. Using all items"
        )
        using_all_items = True
@@ -253,7 +310,8 @@ def embed_mm_inputs(
            data_embedding_func=image_data_embedding_func,
            embedding_items=items,
            placeholder_tensor=(
-                placeholder_tensor
+                # use the specified modality token to identify the location to embed
+                placeholder_tokens[Modality.IMAGE]
                if using_all_items
                else torch.tensor(
                    [item.pad_value for item in items],
@@ -275,7 +333,7 @@ def embed_mm_inputs(
            data_embedding_func=audio_data_embedding_func,
            embedding_items=items,
            placeholder_tensor=(
-                placeholder_tensor
+                placeholder_tokens[Modality.AUDIO]
                if using_all_items
                else torch.tensor(
                    [item.pad_value for item in items],
@@ -296,7 +354,7 @@ def embed_mm_inputs(
        input_ids.clamp_(min=0, max=vocab_size - 1)
        inputs_embeds = input_embedding(input_ids)
 
-    # 4. scatter embeddings into input embedding
+    # 4. Scatter embeddings into input embedding
    for embedding, mask in zip(embeddings, masks):
        mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
        inputs_embeds = inputs_embeds.masked_scatter(
@@ -316,7 +374,7 @@ def general_mm_embed_routine(
    audio_data_embedding_func: Callable[
        [List[MultimodalDataItem]], torch.Tensor
    ] = None,
-    placeholder_token_ids: List[int] = None,
+    placeholder_tokens: dict[Modality, List[int]] = None,
    **kwargs,
 ) -> torch.Tensor:
    """
@@ -328,7 +386,6 @@ def general_mm_embed_routine(
        audio_data_embedding_func : the function returning the image embedding
 
    Returns:
-        inputs_embedding
        forwarded hidden states
 
    """
@@ -346,9 +403,9 @@ def general_mm_embed_routine(
            input_embedding=embed_tokens,
            image_data_embedding_func=image_data_embedding_func,
            audio_data_embedding_func=audio_data_embedding_func,
-            placeholder_token_ids=placeholder_token_ids,
+            placeholder_tokens=placeholder_tokens,
        )
-        # once used, mm_inputs is useless
+        # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
        # just being defensive here
        forward_batch.mm_inputs = None
    else:
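
Note (not part of the diff): the new MultiModalityDataPaddingPatternMultimodalTokens.pad_input_tokens locates contiguous runs of multimodal tokens by padding a boolean mask with False on both ends and diffing it against a shifted copy; even boundary indices are region starts, odd ones are exclusive ends. A standalone sketch of that trick with made-up token ids and pad values:

import torch

IMAGE_TOKEN = 9  # hypothetical multimodal token id
input_ids = torch.tensor([1, 9, 9, 9, 2, 3, 9, 9, 4])
pad_values = [100, 200]  # one pad value per multimodal item

mask = torch.isin(input_ids, torch.tensor([IMAGE_TOKEN]))
# Pad with False so every run of True produces a start/end boundary pair.
padded_mask = torch.cat((torch.tensor([False]), mask, torch.tensor([False])))
boundaries = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
starts, ends = boundaries[::2], boundaries[1::2]

out = input_ids.clone()
for start, end, pad_value in zip(starts, ends, pad_values):
    out[start:end] = pad_value

print(out.tolist())  # [1, 100, 100, 100, 2, 3, 200, 200, 4]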
sglang/srt/managers/multimodal_processors/base_processor.py

@@ -8,6 +8,7 @@ from typing import List, Optional
 
 import numpy as np
 import PIL
+from PIL import Image
 from transformers import BaseImageProcessorFast
 
 from sglang.srt.managers.schedule_batch import Modality
@@ -92,7 +93,12 @@ class BaseMultimodalProcessor(ABC):
 
     @abstractmethod
     async def process_mm_data_async(
-        self, image_data, input_text, max_req_input_len, **kwargs
+        self,
+        image_data,
+        input_text,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
     ):
         pass
 
@@ -104,6 +110,8 @@ class BaseMultimodalProcessor(ABC):
         from decord import VideoReader, cpu
 
         # Before processing inputs
+        if not image_data or len(image_data) == 0:
+            return []
         estimated_frames_list = []
         for image in image_data:
             if isinstance(image, str) and image.startswith("video:"):
@@ -215,6 +223,9 @@ class BaseMultimodalProcessor(ABC):
            discard_alpha_channel: if True, discards the alpha channel in the returned images
 
        """
+
+        if image_data is None:
+            image_data = []
        if isinstance(multimodal_tokens.image_token, int):
            multimodal_tokens.image_token = (
                self._processor.tokenizer.convert_ids_to_tokens(
@@ -229,6 +240,8 @@ class BaseMultimodalProcessor(ABC):
            prompt = self._processor.tokenizer.decode(prompt)
        else:
            prompt = prompt
+
+        assert isinstance(prompt, str)
        if return_text:
            import re
 