sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/multimodal/processors/base_processor.py

@@ -5,8 +5,7 @@ import multiprocessing as mp
 import os
 import re
 from abc import ABC, abstractmethod
-from functools import lru_cache
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -22,7 +21,7 @@ class BaseMultiModalProcessorOutput:
     # input_text, with each frame of video/image represented with a image_token
     input_text: str
 
-    # frames loaded from image and video, in given order
+    # frames loaded from image, in given order
     images: Optional[list[Union[Image.Image, dict]]] = None
 
     # videos
@@ -45,14 +44,26 @@ class BaseMultiModalProcessorOutput:
 
 @dataclasses.dataclass
 class MultimodalSpecialTokens:
-    image_token: Optional[Union[int, str, List[str]]] = None
-    video_token: Optional[Union[int, str, List[str]]] = None
-    audio_token: Optional[Union[int, str, List[str]]] = None
+    image_token: Optional[Union[str, List[str]]] = None
+    video_token: Optional[Union[str, List[str]]] = None
+    audio_token: Optional[Union[str, List[str]]] = None
+
+    image_token_id: Optional[int] = None
+    video_token_id: Optional[int] = None
+    audio_token_id: Optional[int] = None
 
     image_token_regex: Optional[re.Pattern] = None
     video_token_regex: Optional[re.Pattern] = None
     audio_token_regex: Optional[re.Pattern] = None
 
+    combined_regex: Optional[re.Pattern] = None
+
+    def build(self, processor):
+        self.convert_to_strs(processor)
+        self.parse_regex()
+        self.get_combined_regex()
+        return self
+
     def convert_to_str(self, token: Union[str, int], processor) -> str:
         if token is None:
             return token
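
This hunk replaces ad-hoc per-processor token bookkeeping with a build-once pattern: a processor constructs its MultimodalSpecialTokens a single time in __init__, and build() resolves token strings, compiles the per-modality regexes, caches the combined pattern, and returns self for chaining. A minimal sketch of the call site this enables (the id below is made up; real processors read it from hf_config or the processor):

    mm_tokens = MultimodalSpecialTokens(
        image_token="<image>",  # literal special token expected in the prompt
        image_token_id=42,      # hypothetical id; real code takes it from hf_config
    ).build(processor)          # resolve strings, compile and cache the regexes
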
@@ -61,11 +72,14 @@ class MultimodalSpecialTokens:
         return processor.tokenizer.convert_ids_to_tokens([token])[0]
 
     def convert_to_strs(self, processor):
-        self.image_token = self.convert_to_str(self.image_token, processor)
-        self.video_token = self.convert_to_str(self.video_token, processor)
-        self.audio_token = self.convert_to_str(self.audio_token, processor)
-
-    def get_modality_of_token(self, token) -> Optional[Modality]:
+        if not self.image_token:
+            self.image_token = self.convert_to_str(self.image_token_id, processor)
+        if not self.video_token:
+            self.video_token = self.convert_to_str(self.video_token_id, processor)
+        if not self.audio_token:
+            self.audio_token = self.convert_to_str(self.audio_token_id, processor)
+
+    def get_modality_of_token(self, token: str) -> Optional[Modality]:
         """
         :return: the modality associated with the given token, if the token is a special_token or matches with the multimodal token regex
         """
@@ -87,6 +101,14 @@ class MultimodalSpecialTokens:
 
         return None
 
+    def get_token_id_by_modality(self, modality: Modality) -> Optional[int]:
+        return {
+            Modality.IMAGE: self.image_token_id,
+            Modality.MULTI_IMAGES: self.image_token_id,
+            Modality.VIDEO: self.video_token_id,
+            Modality.AUDIO: self.audio_token_id,
+        }.get(modality)
+
     def parse_regex(self):
         if self.image_token_regex is None and self.image_token is not None:
             self.image_token_regex = re.compile(re.escape(self.image_token))
@@ -95,7 +117,12 @@ class MultimodalSpecialTokens:
         if self.audio_token_regex is None and self.audio_token is not None:
             self.audio_token_regex = re.compile(re.escape(self.audio_token))
 
-    def combine_regex(self) -> re.Pattern:
+    def get_combined_regex(self) -> re.Pattern:
+        """
+        Builds and returns a regex, used to split input str into tokens (with mm special tokens)
+        """
+        if self.combined_regex:
+            return self.combined_regex
         tokens = [
             self.image_token_regex,
             self.video_token_regex,
@@ -108,7 +135,8 @@ class MultimodalSpecialTokens:
             patterns.append(t.pattern)
             flags |= t.flags
         combined = "(" + "|".join(f"(?:{p})" for p in patterns) + ")"
-        return re.compile(combined, flags)
+        self.combined_regex = re.compile(combined, flags)
+        return self.combined_regex
 
 
 class BaseMultimodalProcessor(ABC):
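
Worth noting why the whole alternation is wrapped in a single capturing group: re.split keeps the delimiter when the pattern captures, so splitting a prompt on the combined regex yields the plain-text parts and the multimodal tokens, interleaved in order. A self-contained sketch of that behavior (the token strings are examples, not tied to any model):

    import re

    patterns = [re.escape("<image>"), re.escape("<audio>")]
    combined = re.compile("(" + "|".join(f"(?:{p})" for p in patterns) + ")")
    parts = [p for p in combined.split("look at <image> then hear <audio>") if p]
    # ['look at ', '<image>', ' then hear ', '<audio>']
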
@@ -135,27 +163,33 @@ class BaseMultimodalProcessor(ABC):
         self.ATTR_NAME_TO_MODALITY = {
             # Image-related attributes
             "pixel_values": Modality.IMAGE,
-            "pixel_values_videos": Modality.VIDEO,
             "image_sizes": Modality.IMAGE,
             "image_grid_thw": Modality.IMAGE,
+            "image_attention_mask": Modality.IMAGE,
             "image_emb_mask": Modality.IMAGE,
-            "image_spatial_crop": Modality.IMAGE,
+            "images_spatial_crop": Modality.IMAGE,
             "tgt_size": Modality.IMAGE,
             "image_grid_hws": Modality.IMAGE,
-            "aspect_ratio_id": Modality.IMAGE,
+            "aspect_ratio_ids": Modality.IMAGE,
             "aspect_ratio_mask": Modality.IMAGE,
-            "second_per_grid_ts": Modality.IMAGE,
             # Audio-related attributes
             "audio_features": Modality.AUDIO,
             "audio_feature_lens": Modality.AUDIO,
             "input_features": Modality.AUDIO,
             "input_features_mask": Modality.AUDIO,
+            "audio_attention_mask": Modality.AUDIO,
             # Video-related attributes
+            "pixel_values_videos": Modality.VIDEO,
+            "second_per_grid_ts": Modality.VIDEO,
             "video_grid_thw": Modality.VIDEO,
             # Generic attributes that could apply to multiple modalities
-            # "precomputed_features" - handled specially as it can be any modality
+            # "precomputed_embeddings" - handled specially as it can be any modality
         }
 
+        # name of the feature filed
+        # TODO: pass from processors
+        self.FEATURE_NAMES = ["pixel_values", "pixel_values_videos", "audio_features"]
+
     def process_mm_data(
         self, input_text, images=None, videos=None, audios=None, **kwargs
     ):
@@ -196,7 +230,6 @@ class BaseMultimodalProcessor(ABC):
         audio_data,
         input_text,
         request_obj,
-        max_req_input_len,
         **kwargs,
     ) -> Optional[Dict[str, Any]]:
         pass
@@ -227,7 +260,11 @@ class BaseMultimodalProcessor(ABC):
 
     @staticmethod
     def _load_single_item(
-        data, modality: Modality, frame_count_limit=None, discard_alpha_channel=True
+        data,
+        modality: Modality,
+        frame_count_limit=None,
+        audio_sample_rate: Optional[int] = None,
+        discard_alpha_channel=True,
     ):
         """
         Load a single multimodal data.
@@ -244,7 +281,7 @@ class BaseMultimodalProcessor(ABC):
             elif modality == Modality.VIDEO:
                 return load_video(data, frame_count_limit)
             elif modality == Modality.AUDIO:
-                return load_audio(data)
+                return load_audio(data, audio_sample_rate)
 
         except Exception as e:
             raise RuntimeError(f"Error while loading data {data}: {e}")
@@ -253,11 +290,12 @@ class BaseMultimodalProcessor(ABC):
         self,
         text_parts: List[str],
         multimodal_tokens: MultimodalSpecialTokens,
-        data_iterators: dict,
+        data_iterators: dict[Modality, Iterator[Any]],
         discard_alpha_channel: bool = True,
         image_estimated_frames_iter: Optional[iter] = None,
         image_scaling_factor: float = 1.0,
         max_image_frames: int = 30,
+        audio_sample_rate: Optional[int] = None,
     ) -> Tuple[List, List]:
         """
         load multimodal data parallelly using iterators.
@@ -300,6 +338,7 @@ class BaseMultimodalProcessor(ABC):
                         data,
                         modality,
                         frame_count_limit,
+                        audio_sample_rate,
                         discard_alpha_channel,
                     )
                 )
@@ -322,12 +361,12 @@ class BaseMultimodalProcessor(ABC):
         self,
         prompt: str,
         multimodal_tokens: MultimodalSpecialTokens,
-        max_req_input_len: int,
         image_data: Optional[list] = None,
         video_data: Optional[list] = None,
         audio_data: Optional[list] = None,
         return_text: Optional[bool] = True,
         discard_alpha_channel: bool = True,
+        audio_sample_rate: Optional[int] = None,
     ) -> BaseMultiModalProcessorOutput:
         """
         Each frame of video/image will be replaced by a single image token
@@ -338,9 +377,8 @@ class BaseMultimodalProcessor(ABC):
             discard_alpha_channel: if True, discards the alpha channel in the returned images
 
         """
-        multimodal_tokens.convert_to_strs(self._processor)
-        multimodal_tokens.parse_regex()
-        multimodal_tokens_pattern = multimodal_tokens.combine_regex()
+        multimodal_tokens_pattern = multimodal_tokens.get_combined_regex()
+
         if isinstance(prompt, list) and return_text:
             assert len(prompt) and isinstance(prompt[0], int)
             prompt = self._processor.tokenizer.decode(prompt)
@@ -367,6 +405,7 @@ class BaseMultimodalProcessor(ABC):
             multimodal_tokens=multimodal_tokens,
             data_iterators=data_iterators,
             discard_alpha_channel=discard_alpha_channel,
+            audio_sample_rate=audio_sample_rate,
         )
         task_info_iter = iter(task_info)
         futures_iter = iter(futures)
@@ -442,7 +481,6 @@ class BaseMultimodalProcessor(ABC):
         return result = [(2,4),(6,7)]
         """
         mask = input_ids == mm_token_id
-
         start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
         end_positions = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0]
 
@@ -457,50 +495,11 @@ class BaseMultimodalProcessor(ABC):
 
         return list(zip(indices_start.tolist(), indices_end.tolist()))
 
-    @staticmethod
-    def _extract_processor_features(
-        items: List[dict], attr_name: str
-    ) -> Optional[torch.Tensor]:
-        """
-        Helper function to concat extracted attributes from processor output.
-        """
-        values = [value for item in items if (value := item.get(attr_name)) is not None]
-        return torch.cat(values) if values else None
-
-    # When we assume that all the items have the same attributes
-    def _extract_processor_features_from_all_attributes(
-        self, items: List[dict]
-    ) -> dict:
-        values = {}
-        # Verify all items have the same keys
-        first_keys = set(items[0].keys())
-        for item in items[1:]:
-            if set(item.keys()) != first_keys:
-                raise ValueError(
-                    f"All items must have the same attributes. "
-                    f"First item has {first_keys}, but found {set(item.keys())}"
-                )
-
-        # Process each attribute
-        for k, v in items[0].items():
-            if isinstance(v, list):
-                values[k] = self._extract_processor_features(items, k)
-            else:
-                # Verify all items have the same value for non-list attributes
-                for item in items[1:]:
-                    if item[k] != v:
-                        raise ValueError(
-                            f"All items must have the same value for attribute {k}. "
-                            f"First item has {v}, but found {item[k]}"
-                        )
-                values[k] = v
-        return values
-
     def collect_mm_items_from_processor_output(
         self, data_dict: dict
     ) -> List[MultimodalDataItem]:
         """Create mm_items directly from processor output."""
-        items = {}  # modality -> MultimodalDataItem
+        items: dict[Modality, MultimodalDataItem] = {}
 
         for attr_name, value in data_dict.items():
             if attr_name == "input_ids":
@@ -509,23 +508,24 @@ class BaseMultimodalProcessor(ABC):
             # Get modality for this attribute
             modality = self.ATTR_NAME_TO_MODALITY.get(attr_name)
 
-            if not modality and attr_name == "precomputed_features":
+            if attr_name == "precomputed_embeddings":
                 modality_str = data_dict.get("modality")
-                try:
-                    modality = (
-                        Modality.from_str(modality_str)
-                        if modality_str
-                        else Modality.IMAGE
-                    )
-                except ValueError:
-                    modality = Modality.IMAGE
+                modality = Modality.IMAGE
+                if modality_str:
+                    try:
+                        modality = Modality.from_str(modality_str)
+                    except ValueError:
+                        pass
+
             if modality:
                 # Create item if needed
                 if modality not in items:
                     items[modality] = MultimodalDataItem(modality=modality)
 
-                # Set attribute
-                setattr(items[modality], attr_name, value)
+                if attr_name in self.FEATURE_NAMES:
+                    attr_name = "feature"
+
+                items[modality].set(attr_name, value)
 
         return list(items.values())
 
@@ -548,7 +548,10 @@ class BaseMultimodalProcessor(ABC):
         return collected_items, input_ids, ret
 
     def process_and_combine_mm_data(
-        self, base_output: BaseMultiModalProcessorOutput
+        self,
+        base_output: BaseMultiModalProcessorOutput,
+        mm_tokens: MultimodalSpecialTokens,
+        **kwargs,
    ) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
         """
         Process multimodal data and return the combined multimodal items and input_ids.
@@ -581,7 +584,7 @@ class BaseMultimodalProcessor(ABC):
             else:
                 raise ValueError(f"Unknown multimodal item type: {type(item)}")
         # Process items and get input_ids
-        all_collected_items = []
+        all_collected_items: list[MultimodalDataItem] = []
         input_ids = None
 
         # Handle dict items (already processed)
@@ -597,6 +600,7 @@ class BaseMultimodalProcessor(ABC):
                 images=raw_images,
                 audios=raw_audios,
                 videos=raw_videos,
+                **kwargs,
             )
             all_collected_items.extend(collected_items)
         else:
@@ -612,22 +616,12 @@ class BaseMultimodalProcessor(ABC):
 
         # Add offsets to all items
         for mm_item in all_collected_items:
-            if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
-                mm_item.offsets = self.get_mm_items_offset(
-                    input_ids=input_ids,
-                    mm_token_id=self.IM_TOKEN_ID,
-                )
-            elif mm_item.modality == Modality.AUDIO:
-                mm_item.offsets = self.get_mm_items_offset(
-                    input_ids=input_ids,
-                    mm_token_id=self.AUDIO_TOKEN_ID,
-                )
-            elif mm_item.modality == Modality.VIDEO:
-                mm_item.offsets = self.get_mm_items_offset(
-                    input_ids=input_ids,
-                    mm_token_id=self.VIDEO_TOKEN_ID,
-                )
-            else:
-                raise ValueError(f"Unknown modality: {mm_item.modality}")
+            mm_token_id = mm_tokens.get_token_id_by_modality(mm_item.modality)
+            if mm_token_id is None:
+                raise ValueError(f"No token id found for modality: {mm_item.modality}")
+            mm_item.offsets = self.get_mm_items_offset(
+                input_ids=input_ids,
+                mm_token_id=mm_token_id,
+            )
 
         return all_collected_items, input_ids, ret
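
The offset computation that this generalized loop now feeds via mm_tokens.get_token_id_by_modality() is the torch.roll trick from get_mm_items_offset above: a run of multimodal tokens starts where the mask is true but its left neighbor is false, and ends where the mask is true but its right neighbor is false. A quick standalone check of that logic (token id 9 is arbitrary):

    import torch

    input_ids = torch.tensor([1, 9, 9, 9, 2, 9, 9, 3])
    mask = input_ids == 9
    starts = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
    ends = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0]
    # note: torch.roll wraps around; harmless here since the sequence
    # neither starts nor ends with token 9
    print(list(zip(starts.tolist(), ends.tolist())))  # [(1, 3), (5, 6)]
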
sglang/srt/multimodal/processors/clip.py

@@ -1,9 +1,10 @@
 from typing import List, Union
 
-from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.clip import CLIPModel
-from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
-from sglang.srt.utils import load_image
+from sglang.srt.multimodal.processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
 
 
 class ClipImageProcessor(BaseMultimodalProcessor):
@@ -11,23 +12,24 @@ class ClipImageProcessor(BaseMultimodalProcessor):
 
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
+        self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
+            _processor
+        )
 
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
-        if isinstance(input_text, list):
-            assert len(input_text) and isinstance(input_text[0], int)
-            input_text = self._processor.tokenizer.decode(input_text)
-
-        images = [load_image(image)[0] for image in image_data]
-
-        image_inputs = self.process_mm_data(input_text=input_text, images=images)
-        image_inputs["data_hashes"] = [hash(str(image_data))]
-        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
-        image_inputs["mm_items"] = [
-            MultimodalDataItem(
-                pixel_values=image_inputs["pixel_values"], modality=Modality.IMAGE
-            )
-        ]
-
-        return image_inputs
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            multimodal_tokens=self.mm_tokens,
+            image_data=image_data,
+        )
+
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens
+        )
+
+        return {
+            "input_ids": input_ids.tolist(),
+            "mm_items": mm_items,
+        }
sglang/srt/multimodal/processors/deepseek_vl_v2.py

@@ -33,7 +33,9 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
 
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
-        self.IMAGE_TOKEN = "<image>"
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token="<image>", image_token_id=self._processor.image_token_id
+        ).build(_processor)
 
     async def process_mm_data_async(
         self,
@@ -47,37 +49,17 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         base_output = self.load_mm_data(
             input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
-            max_req_input_len=max_req_input_len,
+            multimodal_tokens=self.mm_tokens,
         )
-        res = self.process_mm_data(
-            input_text=base_output.input_text,
-            images=base_output.images,
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(
+            base_output,
+            self.mm_tokens,
             max_req_input_len=max_req_input_len,
             conversations=base_output.input_text,
         )
-        images_seq_mask = res["images_seq_mask"]
-        images_spatial_crop = res["images_spatial_crop"]
-        batched_images_spatial_crop = []
-        batched_images_spatial_crop.append(images_spatial_crop)
-        batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
-
-        items = []
-        input_ids = res["input_ids"]
-        image_offsets = self.get_mm_items_offset(
-            input_ids=input_ids, mm_token_id=self._processor.image_token_id
-        )
-        item = MultimodalDataItem(
-            pixel_values=res["images"],
-            offsets=image_offsets,
-            modality=Modality.IMAGE,
-            image_emb_mask=images_seq_mask,
-            image_spatial_crop=batched_images_spatial_crop,
-        )
-        items += [item]
 
         return {
-            "mm_items": items,
+            "mm_items": mm_items,
             "input_ids": input_ids.tolist(),
             "im_token_id": self._processor.image_token_id,
         }
sglang/srt/multimodal/processors/gemma3.py

@@ -4,7 +4,6 @@ from typing import Dict, List, Union
 from sglang.srt.managers.multimodal_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
-from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
 from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens
 
@@ -17,39 +16,36 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
 
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
-        # The single, pre-expanded image token.
-        self.IMAGE_TOKEN = "<start_of_image>"
-        # The regex that matches expanded image tokens.
-        self.IMAGE_TOKEN_REGEX = re.compile(
-            r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
-        )
         self.IM_START_TOKEN_ID = hf_config.boi_token_index
         self.IM_END_TOKEN_ID = hf_config.eoi_token_index
-        self.IM_TOKEN_ID = hf_config.image_token_index
+        self.mm_tokens = MultimodalSpecialTokens(
+            # The single, pre-expanded image token.
+            image_token="<start_of_image>",
+            image_token_id=hf_config.image_token_index,
+            # The regex that matches expanded image tokens.
+            image_token_regex=re.compile(
+                r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
+            ),
+        ).build(_processor)
 
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes, Dict]],
         input_text,
         request_obj,
-        max_req_input_len,
         *args,
         **kwargs,
     ):
-        print(f"{image_data=}")
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(
-                image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
-            ),
-            max_req_input_len=max_req_input_len,
+            multimodal_tokens=self.mm_tokens,
             discard_alpha_channel=True,
         )
 
-        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
-        print(f"{base_output=}")
-        print(f"{mm_items=}")
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens
+        )
         return {
             "input_ids": input_ids.tolist(),
             "mm_items": mm_items,
sglang/srt/multimodal/processors/gemma3n.py

@@ -30,23 +30,23 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
 
-        self.IMAGE_TOKEN = "<image_soft_token>"
-        self.IMAGE_TOKEN_REGEX = re.compile(
-            r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
-        )
-
-        self.AUDIO_TOKEN = "<audio_soft_token>"
-        self.AUDIO_TOKEN_REGEX = re.compile(
-            r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
-        )
-
-        self.IM_TOKEN_ID = hf_config.image_token_id
         self.IM_START_TOKEN_ID = hf_config.boi_token_id
         self.IM_END_TOKEN_ID = hf_config.eoi_token_id
 
-        self.AUDIO_TOKEN_ID = hf_config.audio_token_id
         self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id
         self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token="<image_soft_token>",
+            image_token_id=hf_config.image_token_id,
+            image_token_regex=re.compile(
+                r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
+            ),
+            audio_token="<audio_soft_token>",
+            audio_token_id=hf_config.audio_token_id,
+            audio_token_regex=re.compile(
+                r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
+            ),
+        ).build(_processor)
 
     async def process_mm_data_async(
         self,
@@ -54,7 +54,6 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
         audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
         input_text: str = "",
         request_obj=None,
-        max_req_input_len: int = 0,
         *args,
         **kwargs,
     ):
@@ -63,20 +62,17 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
             prompt=input_text,
             image_data=image_data,
             audio_data=audio_data,
-            max_req_input_len=max_req_input_len,
-            multimodal_tokens=MultimodalSpecialTokens(
-                image_token=self.IMAGE_TOKEN,
-                image_token_regex=self.IMAGE_TOKEN_REGEX,
-                audio_token=self.AUDIO_TOKEN,
-                audio_token_regex=self.AUDIO_TOKEN_REGEX,
-            ),
+            multimodal_tokens=self.mm_tokens,
         )
 
-        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens
+        )
 
         return {
            "input_ids": input_ids.tolist(),
            "mm_items": mm_items,
-           "im_token_id": self.IM_TOKEN_ID,
-           "audio_token_id": self.AUDIO_TOKEN_ID,
+           # TODO(mick): could we return MultimodalSpecialTokens directly?
+           "im_token_id": self.mm_tokens.image_token_id,
+           "audio_token_id": self.mm_tokens.audio_token_id,
         }