sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +133 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +32 -21
  49. sglang/srt/layers/layernorm.py +24 -2
  50. sglang/srt/layers/linear.py +17 -5
  51. sglang/srt/layers/logits_processor.py +25 -7
  52. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  53. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  54. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  55. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  62. sglang/srt/layers/moe/topk.py +31 -18
  63. sglang/srt/layers/parameter.py +1 -1
  64. sglang/srt/layers/quantization/__init__.py +184 -126
  65. sglang/srt/layers/quantization/base_config.py +5 -0
  66. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  67. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  69. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  70. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  71. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  73. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  74. sglang/srt/layers/quantization/fp8.py +76 -34
  75. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  76. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  77. sglang/srt/layers/quantization/gptq.py +36 -9
  78. sglang/srt/layers/quantization/kv_cache.py +98 -0
  79. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  80. sglang/srt/layers/quantization/utils.py +153 -0
  81. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  82. sglang/srt/layers/rotary_embedding.py +66 -87
  83. sglang/srt/layers/sampler.py +1 -1
  84. sglang/srt/lora/layers.py +68 -0
  85. sglang/srt/lora/lora.py +2 -22
  86. sglang/srt/lora/lora_manager.py +47 -23
  87. sglang/srt/lora/mem_pool.py +110 -51
  88. sglang/srt/lora/utils.py +12 -1
  89. sglang/srt/managers/cache_controller.py +2 -5
  90. sglang/srt/managers/data_parallel_controller.py +30 -8
  91. sglang/srt/managers/expert_distribution.py +81 -0
  92. sglang/srt/managers/io_struct.py +39 -3
  93. sglang/srt/managers/mm_utils.py +373 -0
  94. sglang/srt/managers/multimodal_processor.py +68 -0
  95. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  96. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  97. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  98. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  99. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  100. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  101. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  102. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  103. sglang/srt/managers/schedule_batch.py +133 -30
  104. sglang/srt/managers/scheduler.py +273 -20
  105. sglang/srt/managers/session_controller.py +1 -1
  106. sglang/srt/managers/tokenizer_manager.py +59 -23
  107. sglang/srt/managers/tp_worker.py +1 -1
  108. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  109. sglang/srt/managers/utils.py +6 -1
  110. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  111. sglang/srt/mem_cache/memory_pool.py +255 -98
  112. sglang/srt/mem_cache/paged_allocator.py +2 -2
  113. sglang/srt/mem_cache/radix_cache.py +4 -4
  114. sglang/srt/model_executor/cuda_graph_runner.py +27 -13
  115. sglang/srt/model_executor/forward_batch_info.py +68 -11
  116. sglang/srt/model_executor/model_runner.py +70 -6
  117. sglang/srt/model_loader/loader.py +160 -2
  118. sglang/srt/model_loader/weight_utils.py +45 -0
  119. sglang/srt/models/deepseek_janus_pro.py +29 -86
  120. sglang/srt/models/deepseek_nextn.py +22 -10
  121. sglang/srt/models/deepseek_v2.py +208 -77
  122. sglang/srt/models/deepseek_vl2.py +358 -0
  123. sglang/srt/models/gemma3_causal.py +684 -0
  124. sglang/srt/models/gemma3_mm.py +462 -0
  125. sglang/srt/models/llama.py +47 -7
  126. sglang/srt/models/llama_eagle.py +1 -0
  127. sglang/srt/models/llama_eagle3.py +196 -0
  128. sglang/srt/models/llava.py +3 -3
  129. sglang/srt/models/llavavid.py +3 -3
  130. sglang/srt/models/minicpmo.py +1995 -0
  131. sglang/srt/models/minicpmv.py +62 -137
  132. sglang/srt/models/mllama.py +4 -4
  133. sglang/srt/models/phi3_small.py +1 -1
  134. sglang/srt/models/qwen2.py +3 -0
  135. sglang/srt/models/qwen2_5_vl.py +68 -146
  136. sglang/srt/models/qwen2_classification.py +75 -0
  137. sglang/srt/models/qwen2_moe.py +9 -1
  138. sglang/srt/models/qwen2_vl.py +25 -63
  139. sglang/srt/openai_api/adapter.py +124 -28
  140. sglang/srt/openai_api/protocol.py +23 -2
  141. sglang/srt/sampling/sampling_batch_info.py +1 -1
  142. sglang/srt/sampling/sampling_params.py +6 -6
  143. sglang/srt/server_args.py +99 -9
  144. sglang/srt/speculative/build_eagle_tree.py +7 -347
  145. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  146. sglang/srt/speculative/eagle_utils.py +208 -252
  147. sglang/srt/speculative/eagle_worker.py +139 -53
  148. sglang/srt/speculative/spec_info.py +6 -1
  149. sglang/srt/torch_memory_saver_adapter.py +22 -0
  150. sglang/srt/utils.py +182 -21
  151. sglang/test/__init__.py +0 -0
  152. sglang/test/attention/__init__.py +0 -0
  153. sglang/test/attention/test_flashattn_backend.py +312 -0
  154. sglang/test/runners.py +2 -0
  155. sglang/test/test_activation.py +2 -1
  156. sglang/test/test_block_fp8.py +5 -4
  157. sglang/test/test_block_fp8_ep.py +2 -1
  158. sglang/test/test_dynamic_grad_mode.py +58 -0
  159. sglang/test/test_layernorm.py +3 -2
  160. sglang/test/test_utils.py +55 -4
  161. sglang/utils.py +31 -0
  162. sglang/version.py +1 -1
  163. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  164. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
  165. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  166. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  167. sglang/srt/managers/image_processor.py +0 -55
  168. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  169. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  170. sglang/srt/managers/multi_modality_padding.py +0 -134
  171. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  172. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multimodal_processors/base_processor.py
@@ -0,0 +1,275 @@
+ import concurrent
+ import concurrent.futures
+ import dataclasses
+ import multiprocessing as mp
+ import os
+ from abc import ABC, abstractmethod
+ from typing import Optional
+
+ import numpy as np
+ import PIL
+ import transformers
+ from decord import VideoReader, cpu
+ from openai import BadRequestError
+ from PIL import Image
+
+ from sglang.srt.utils import load_audio, load_image, logger
+
+ global global_processor
+
+
+ def get_global_processor():
+     global global_processor
+     return global_processor
+
+
+ @dataclasses.dataclass
+ class BaseMultiModalProcessorOutput:
+     # input_text, with each frame of video/image represented with a image_token
+     input_text: str
+
+     mm_data_hashes: Optional[list[int]]
+     # images
+     image_sizes: Optional[list[int]]
+     # frames loaded from image and video, in given order
+     images: Optional[list[PIL.Image]] = None
+
+     # audios
+     audios: Optional[list[np.ndarray]] = None
+
+     def normalize(self):
+         for field_name in ["data_hashes", "image_sizes", "images", "audios"]:
+             field = getattr(self, field_name, None)
+             if field is not None and isinstance(field, list) and len(field) == 0:
+                 setattr(self, field_name, None)
+
+
+ @dataclasses.dataclass
+ class MultimodalSpecialTokens:
+     image_token: Optional[str] = None
+     video_token: Optional[str] = None
+     audio_token: Optional[str] = None
+
+     def collect(self) -> list[str]:
+         return [
+             token
+             for token in [self.image_token, self.video_token, self.audio_token]
+             if token
+         ]
+
+
+ class BaseMultimodalProcessor(ABC):
+     models = []
+
+     def __init__(self, hf_config, server_args, _processor):
+         self.hf_config = hf_config
+         self._processor = _processor
+         self.server_args = server_args
+         # FIXME: not accurate, model and image specific
+         self.NUM_TOKEN_PER_FRAME = 330
+
+         # Initialize global processor first
+         init_global_processor(self, server_args)
+
+         self.executor = concurrent.futures.ProcessPoolExecutor(
+             initializer=init_global_processor,
+             mp_context=mp.get_context("fork"),
+             initargs=(
+                 self,
+                 server_args,
+             ),
+             max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
+         )
+
+     def _build_processor(self, server_args):
+         """Init the global processor for multi modal models."""
+         from sglang.srt.hf_transformers_utils import get_processor
+
+         return get_processor(
+             server_args.tokenizer_path,
+             tokenizer_mode=server_args.tokenizer_mode,
+             trust_remote_code=server_args.trust_remote_code,
+         )
+
+     @abstractmethod
+     async def process_mm_data_async(
+         self, image_data, input_text, max_req_input_len, **kwargs
+     ):
+         pass
+
+     def get_estimated_frames_list(self, image_data):
+         """
+         estimate the total frame count from all visual input
+         """
+         # Before processing inputs
+         estimated_frames_list = []
+         for image in image_data:
+             if isinstance(image, str) and image.startswith("video:"):
+                 path = image[len("video:") :]
+                 # Estimate frames for the video
+                 vr = VideoReader(path, ctx=cpu(0))
+                 num_frames = len(vr)
+             else:
+                 # For images, each contributes one frame
+                 num_frames = 1
+             estimated_frames_list.append(num_frames)
+
+         return estimated_frames_list
+
+     @staticmethod
+     def encode_video(video_path, frame_count_limit=None):
+         if not os.path.exists(video_path):
+             logger.error(f"Video {video_path} does not exist")
+             return []
+
+         if frame_count_limit == 0:
+             return []
+
+         def uniform_sample(l, n):
+             gap = len(l) / n
+             idxs = [int(i * gap + gap / 2) for i in range(n)]
+             return [l[i] for i in idxs]
+
+         vr = VideoReader(video_path, ctx=cpu(0))
+         sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+         frame_indices = [i for i in range(0, len(vr), sample_fps)]
+         if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
+             frame_indices = uniform_sample(frame_indices, frame_count_limit)
+
+         frames = vr.get_batch(frame_indices).asnumpy()
+         frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+         return frames
+
+     def load_mm_data(
+         self,
+         input_ids: list[int],
+         multimodal_tokens: MultimodalSpecialTokens,
+         max_req_input_len: int,
+         image_data: Optional[list] = None,
+         audio_data: Optional[list] = None,
+         return_text: Optional[bool] = True,
+         discard_alpha_channel: bool = True,
+     ) -> BaseMultiModalProcessorOutput:
+         """
+         Each frame of video/image will be replaced by a single image token
+
+         Args:
+             multimodal_tokens (list[str]): list of special token which denoting a single multimodal data
+                 e.g. image token or audio token
+             discard_alpha_channel: if True, discards the alpha channel in the returned images
+
+         """
+         if isinstance(multimodal_tokens.image_token, int):
+             multimodal_tokens.image_token = (
+                 self._processor.tokenizer.convert_ids_to_tokens(
+                     multimodal_tokens.image_token
+                 )
+             )
+         else:
+             multimodal_tokens.image_token = multimodal_tokens.image_token
+
+         if isinstance(input_ids, list) and return_text:
+             assert len(input_ids) and isinstance(input_ids[0], int)
+             input_text = self._processor.tokenizer.decode(input_ids)
+         else:
+             input_text = input_ids
+         if return_text:
+             import re
+
+             pattern = (
+                 "("
+                 + "|".join(re.escape(sep) for sep in multimodal_tokens.collect())
+                 + ")"
+             )
+             # split text into list of normal text and special tokens
+             text_parts = re.split(pattern, input_text)
+
+         # TODO(mick): load from server_args, env, or sampling_params
+         MAX_NUM_FRAMES = 30
+         estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
+         total_frame_count = sum(estimated_frames_list)
+         # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
+         # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
+         scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
+
+         assert len(image_data) == len(estimated_frames_list)
+
+         image_index, audio_index = 0, 0
+         hashes, image_sizes, images, audios = [], [], [], []
+         new_text = ""
+         for index, text_part in enumerate(text_parts):
+             try:
+                 if text_part == multimodal_tokens.image_token:
+                     # load as image
+                     if len(images) >= MAX_NUM_FRAMES:
+                         frames_to_process = 0
+                     else:
+                         estimated_frames = estimated_frames_list[image_index]
+                         frames_to_process = max(
+                             1, int(estimated_frames * scaling_factor)
+                         )
+
+                     if frames_to_process == 0:
+                         frames = []
+                     else:
+                         image_file = image_data[image_index]
+                         if isinstance(image_file, str) and image_file.startswith(
+                             "video:"
+                         ):
+                             # video
+                             path = image_file[len("video:") :]
+                             frames = BaseMultimodalProcessor.encode_video(
+                                 path, frame_count_limit=frames_to_process
+                             )
+                         else:
+                             # image
+                             raw_image, _size = load_image(image_file)
+                             if discard_alpha_channel:
+                                 raw_image = raw_image.convert("RGB")
+                             frames = [raw_image]
+                     if len(frames) == 0:
+                         continue
+
+                     image_sizes += frames[0].size * len(frames)
+                     hashes += [hash(image_file)] * len(frames)
+                     images += frames
+                     image_index += 1
+                     if frames_to_process != 0:
+                         new_text += multimodal_tokens.image_token * len(frames)
+                     assert frames_to_process == len(frames)
+                 elif text_part == multimodal_tokens.audio_token:
+                     # load as audio
+                     audio_file = audio_data[audio_index]
+                     audio = load_audio(audio_file)
+                     hashes += [hash(audio_file)]
+                     audios += [audio]
+                     audio_index += 1
+                     new_text += multimodal_tokens.audio_token
+                 else:
+                     # TODO(mick): handle video
+                     # normal text
+                     new_text += text_part
+
+             except Exception as e:
+                 logger.error(f"An exception occurred while loading images: {e}")
+                 raise BadRequestError(
+                     f"An exception occurred while loading images: {e}"
+                 )
+
+         out = BaseMultiModalProcessorOutput(
+             mm_data_hashes=hashes,
+             image_sizes=image_sizes,
+             images=images,
+             audios=audios,
+             input_text=new_text,
+         )
+         out.normalize()
+         return out
+
+
+ def init_global_processor(sglang_processor: BaseMultimodalProcessor, server_args):
+     """
+     Init the global processor for multimodal models."""
+     global global_processor
+     transformers.logging.set_verbosity_error()
+     global_processor = sglang_processor._build_processor(server_args=server_args)
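
For orientation, the following is a minimal sketch of how a model-specific processor is expected to build on this base class, assuming only the pattern visible in the DeepSeek-VL2, Gemma3, and Janus-Pro processors elsewhere in this diff. MyModelForCausalLM, MyImageProcessor, the "<image>" token, and the keyword arguments passed to the HF processor are placeholders, not part of this release.

import asyncio

from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
    get_global_processor,
)


class MyModelForCausalLM:
    """Stand-in for a real sglang model class (placeholder only)."""


class MyImageProcessor(BaseMultimodalProcessor):
    # Model classes served by this processor; this attribute replaces the old
    # per-file ImageProcessorMapping dicts (exact lookup lives elsewhere).
    models = [MyModelForCausalLM]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        self.IMAGE_TOKEN = "<image>"  # placeholder special token

    @staticmethod
    def _process_task(images, input_text):
        # Runs inside the fork-based ProcessPoolExecutor created by the base class;
        # the HF processor is reachable there via the global set by init_global_processor.
        # The kwargs below are placeholders; real processors pass whatever their
        # HF processor expects.
        return get_global_processor()(
            text=[input_text], images=images, return_tensors="pt"
        )

    async def process_mm_data_async(
        self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
    ):
        if not image_data:
            return None
        if not isinstance(image_data, list):
            image_data = [image_data]
        # Expand image/video placeholders in the prompt into per-frame image tokens
        # and load the corresponding frames.
        base_output = self.load_mm_data(
            input_ids=input_ids,
            image_data=image_data,
            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
            max_req_input_len=max_req_input_len,
        )
        loop = asyncio.get_event_loop()
        res = await loop.run_in_executor(
            self.executor,
            MyImageProcessor._process_task,
            base_output.images,
            base_output.input_text,
        )
        return {
            "input_ids": res["input_ids"].flatten().tolist(),
            "pixel_values": res["pixel_values"],
            "data_hashes": base_output.mm_data_hashes,
            "modalities": request_obj.modalities or ["image"],
        }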

sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py
@@ -0,0 +1,119 @@
+ # Copyright (c) 2023-2024 DeepSeek.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ import asyncio
+
+ import torch
+
+ from sglang.srt.managers.multimodal_processors.base_processor import (
+     BaseMultimodalProcessor,
+     MultimodalSpecialTokens,
+     get_global_processor,
+ )
+ from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
+
+
+ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
+     models = [DeepseekVL2ForCausalLM]
+
+     def __init__(self, hf_config, server_args, _processor):
+         super().__init__(hf_config, server_args, _processor)
+         self.IMAGE_TOKEN = "<image>"
+
+     @staticmethod
+     def _process_images_task(image, input_text, max_req_input_len):
+         processor = get_global_processor()
+         res = processor.__call__(
+             conversations=input_text, images=image, max_req_input_len=max_req_input_len
+         )
+
+         image_token_id = processor.image_token_id
+
+         res["im_token_id"] = image_token_id
+         return res
+
+     async def _process_images(self, image_data, input_text, max_req_input_len):
+         if self.executor is not None:
+             loop = asyncio.get_event_loop()
+             image_inputs = await loop.run_in_executor(
+                 self.executor,
+                 DeepseekVL2ImageProcessor._process_images_task,
+                 image_data,
+                 input_text,
+                 max_req_input_len,
+             )
+         else:
+             image_inputs = self._process_images_task(
+                 image_data, input_text, max_req_input_len
+             )
+
+         return image_inputs
+
+     async def _process_images(self, image_data, input_text, max_req_input_len):
+         if self.executor is not None:
+             loop = asyncio.get_event_loop()
+             image_inputs = await loop.run_in_executor(
+                 self.executor,
+                 DeepseekVL2ImageProcessor._process_images_task,
+                 image_data,
+                 input_text,
+                 max_req_input_len,
+             )
+         else:
+             image_inputs = self._process_images_task(
+                 image_data, input_text, max_req_input_len
+             )
+         return image_inputs
+
+     async def process_mm_data_async(
+         self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
+     ):
+         if not image_data:
+             return None
+
+         if not isinstance(image_data, list):
+             image_data = [image_data]
+
+         images, image_sizes = [], []
+
+         image_token = self.IMAGE_TOKEN
+         base_output = self.load_mm_data(
+             input_ids,
+             image_data=image_data,
+             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
+             max_req_input_len=max_req_input_len,
+         )
+         res = await self._process_images(
+             base_output.images, base_output.input_text, max_req_input_len
+         )
+         images_seq_mask = res["images_seq_mask"]
+         images_spatial_crop = res["images_spatial_crop"]
+         batched_images_spatial_crop = []
+         batched_images_spatial_crop.append(images_spatial_crop)
+         batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
+
+         return {
+             "input_ids": res["input_ids"].tolist(),
+             "pixel_values": res["images"],
+             "im_token_id": res["im_token_id"],
+             "data_hashes": base_output.mm_data_hashes,
+             "image_sizes": image_sizes,
+             "images_emb_mask": images_seq_mask,
+             "image_spatial_crop": batched_images_spatial_crop,
+             "modalities": request_obj.modalities or ["image"],
+         }

sglang/srt/managers/multimodal_processors/gemma3.py
@@ -0,0 +1,83 @@
+ from typing import List, Union
+
+ from transformers.utils import logging
+
+ from sglang.srt.managers.multimodal_processor import (
+     BaseMultimodalProcessor as SGLangBaseProcessor,
+ )
+ from sglang.srt.managers.multimodal_processors.base_processor import (
+     MultimodalSpecialTokens,
+     get_global_processor,
+ )
+ from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
+
+ # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
+ # will be removed in the future
+ logger = logging.get_logger(__name__)
+
+
+ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
+     models = [Gemma3ForConditionalGeneration]
+
+     def __init__(self, hf_config, server_args, _processor):
+         super().__init__(hf_config, server_args, _processor)
+         self.IMAGE_TOKEN = "<start_of_image>"
+         self.IM_START_TOKEN_ID = hf_config.boi_token_index
+         self.IM_END_TOKEN_ID = hf_config.eoi_token_index
+
+     async def _process_single_image(self, images, input_text) -> dict:
+         if isinstance(images, list) and len(images) == 0:
+             images = None
+         processor = get_global_processor()
+         result = processor.__call__(
+             text=[input_text],
+             images=images,
+             padding=True,
+             return_tensors="pt",
+             # if RGBA, this needs to be set
+             # images_kwargs={
+             #     "input_data_format": ChannelDimension.FIRST
+             # }
+         )
+
+         pixel_values = getattr(result, "pixel_values", None)
+
+         return {
+             "input_ids": result.input_ids,
+             "pixel_values": pixel_values,
+         }
+
+     async def process_mm_data_async(
+         self,
+         image_data: List[Union[str, bytes]],
+         input_ids,
+         request_obj,
+         max_req_input_len,
+         *args,
+         **kwargs,
+     ):
+         if not image_data:
+             return None
+         if isinstance(image_data, str):
+             image_data = [image_data]
+
+         image_token = self.IMAGE_TOKEN
+         base_output = self.load_mm_data(
+             input_ids=input_ids,
+             image_data=image_data,
+             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
+             max_req_input_len=max_req_input_len,
+             discard_alpha_channel=True,
+         )
+
+         ret = await self._process_single_image(
+             input_text=base_output.input_text, images=base_output.images
+         )
+
+         return {
+             "input_ids": ret["input_ids"].flatten().tolist(),
+             "pixel_values": ret["pixel_values"],
+             "data_hashes": base_output.mm_data_hashes,
+             "im_start_id": self.IM_START_TOKEN_ID,
+             "im_end_id": self.IM_END_TOKEN_ID,
+         }

sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py
@@ -1,16 +1,17 @@
  import asyncio
  from typing import List, Union

- from sglang.srt.managers.image_processors.base_image_processor import (
-     BaseImageProcessor as SGLangBaseImageProcessor,
- )
- from sglang.srt.managers.image_processors.base_image_processor import (
+ from sglang.srt.managers.multimodal_processors.base_processor import (
+     BaseMultimodalProcessor,
+     MultimodalSpecialTokens,
      get_global_processor,
  )
  from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM


- class JanusProProcessor(SGLangBaseImageProcessor):
+ class JanusProImageProcessor(BaseMultimodalProcessor):
+     models = [MultiModalityCausalLM]
+
      def __init__(self, hf_config, server_args, _processor):
          super().__init__(hf_config, server_args, _processor)

@@ -34,7 +35,7 @@ class JanusProProcessor(SGLangBaseImageProcessor):
              loop = asyncio.get_event_loop()
              image_inputs = await loop.run_in_executor(
                  self.executor,
-                 JanusProProcessor._process_images_task,
+                 JanusProImageProcessor._process_images_task,
                  images,
                  input_text,
              )
@@ -45,7 +46,7 @@ class JanusProProcessor(SGLangBaseImageProcessor):

          return image_inputs

-     async def process_images_async(
+     async def process_mm_data_async(
          self,
          image_data: List[Union[str, bytes]],
          input_ids,
@@ -59,21 +60,25 @@ class JanusProProcessor(SGLangBaseImageProcessor):
          if not isinstance(image_data, list):
              image_data = [image_data]

-         base_out = self.load_images(
-             input_ids, image_data, "<image_placeholder>", max_req_input_len
+         base_out = self.load_mm_data(
+             input_ids=input_ids,
+             image_data=image_data,
+             multimodal_tokens=MultimodalSpecialTokens(
+                 image_token="<image_placeholder>"
+             ),
+             max_req_input_len=max_req_input_len,
          )
-         images = base_out.all_frames
+         images = base_out.images
          res = await self._process_images(images=images, input_text=base_out.input_text)
-
+         # print(res)
+         # print(base_out)
+         # print("", res["images_emb_mask"].shape)
          return {
              "input_ids": res["input_ids"].flatten().tolist(),
              "pixel_values": res["pixel_values"],
              "images_emb_mask": res["images_emb_mask"],
-             "image_hashes": base_out.image_hashes,
+             "data_hashes": base_out.mm_data_hashes,
              "im_start_id": res["im_start_id"],
              "im_end_id": res["im_end_id"],
              "im_token_id": res["im_token_id"],
          }
-
-
- ImageProcessorMapping = {MultiModalityCausalLM: JanusProProcessor}

sglang/srt/managers/{image_processors → multimodal_processors}/llava.py
@@ -3,8 +3,8 @@ from typing import List, Optional, Union

  import numpy as np

- from sglang.srt.managers.image_processor import BaseImageProcessor
- from sglang.srt.managers.image_processors.base_image_processor import (
+ from sglang.srt.managers.multimodal_processors.base_processor import (
+     BaseMultimodalProcessor,
      get_global_processor,
  )
  from sglang.srt.mm_utils import expand2square, process_anyres_image
@@ -14,7 +14,9 @@ from sglang.srt.utils import load_image, logger
  from sglang.utils import get_exception_traceback


- class LlavaImageProcessor(BaseImageProcessor):
+ class LlavaImageProcessor(BaseMultimodalProcessor):
+     models = [LlavaVidForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
+
      def __init__(self, hf_config, server_args, _processor):
          super().__init__(hf_config, server_args, _processor)

@@ -84,7 +86,7 @@ class LlavaImageProcessor(BaseImageProcessor):
                  image_data, aspect_ratio, grid_pinpoints
              )

-     async def process_images_async(
+     async def process_mm_data_async(
          self,
          image_data: List[Union[str, bytes]],
          input_text,
@@ -111,7 +113,7 @@ class LlavaImageProcessor(BaseImageProcessor):
          if "multi-images" in modalities or "video" in modalities:
              # Multiple images
              aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
-             pixel_values, image_hashes, image_sizes = [], [], []
+             pixel_values, data_hashes, image_sizes = [], [], []
              res = []
              for img_data in image_data:
                  res.append(
@@ -122,7 +124,7 @@ class LlavaImageProcessor(BaseImageProcessor):
              res = await asyncio.gather(*res)
              for pixel_v, image_h, image_s in res:
                  pixel_values.append(pixel_v)
-                 image_hashes.append(image_h)
+                 data_hashes.append(image_h)
                  image_sizes.append(image_s)

              if isinstance(pixel_values[0], np.ndarray):
@@ -132,21 +134,14 @@ class LlavaImageProcessor(BaseImageProcessor):
              pixel_values, image_hash, image_size = await self._process_single_image(
                  image_data[0], aspect_ratio, grid_pinpoints
              )
-             image_hashes = [image_hash]
+             data_hashes = [image_hash]
              image_sizes = [image_size]
          else:
              raise ValueError(f"Invalid image data: {image_data}")

          return {
              "pixel_values": pixel_values,
-             "image_hashes": image_hashes,
+             "data_hashes": data_hashes,
              "image_sizes": image_sizes,
              "modalities": request_obj.modalities or ["image"],
          }
-
-
- ImageProcessorMapping = {
-     LlavaVidForCausalLM: LlavaImageProcessor,
-     LlavaQwenForCausalLM: LlavaImageProcessor,
-     LlavaMistralForCausalLM: LlavaImageProcessor,
- }
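
The ImageProcessorMapping dicts deleted at the end of janus_pro.py and llava.py are superseded by the models class attribute each processor now declares. The dispatch logic itself lives in the new sglang/srt/managers/multimodal_processor.py (not shown in this diff), so the following is only an assumed sketch of how such a registry could be derived from that attribute; build_processor_mapping is a hypothetical helper, not a function in this release.

from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
)


def build_processor_mapping():
    # Assumed sketch: map each declared model class to its processor class by
    # walking BaseMultimodalProcessor subclasses and their `models` lists. The
    # real registration in multimodal_processor.py may differ (for example, it
    # may import the processor modules explicitly before collecting them).
    mapping = {}
    for processor_cls in BaseMultimodalProcessor.__subclasses__():
        for model_cls in processor_cls.models:
            mapping[model_cls] = processor_cls
    return mapping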