sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/image_processor.py (deleted)
@@ -1,55 +0,0 @@
-# TODO: also move pad_input_ids into this module
-import importlib
-import logging
-import pkgutil
-from functools import lru_cache
-
-from transformers import IMAGE_PROCESSOR_MAPPING
-
-from sglang.srt.managers.image_processors.base_image_processor import (
-    BaseImageProcessor,
-    DummyImageProcessor,
-)
-from sglang.srt.server_args import ServerArgs
-
-logger = logging.getLogger(__name__)
-
-
-IMAGE_PROCESSOR_MAPPING = {}
-
-
-def get_image_processor(
-    hf_config, server_args: ServerArgs, processor
-) -> BaseImageProcessor:
-    for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items():
-        if model_cls.__name__ in hf_config.architectures:
-            return processor_cls(hf_config, server_args, processor)
-    raise ValueError(
-        f"No image processor found for architecture: {hf_config.architectures}"
-    )
-
-
-def get_dummy_image_processor():
-    return DummyImageProcessor()
-
-
-@lru_cache()
-def import_image_processors():
-    package_name = "sglang.srt.managers.image_processors"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            try:
-                module = importlib.import_module(name)
-            except Exception as e:
-                logger.warning(f"Ignore import error when loading {name}: " f"{e}")
-                continue
-            if hasattr(module, "ImageProcessorMapping"):
-                entry = module.ImageProcessorMapping
-                if isinstance(entry, dict):
-                    for processor_name, cls in entry.items():
-                        IMAGE_PROCESSOR_MAPPING[processor_name] = cls
-
-
-# also register processors
-import_image_processors()
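The module deleted above implements a plugin-style registry: each submodule of the image-processor package may expose an ImageProcessorMapping dict, and a one-time scan merges those dicts into a process-wide registry keyed by model class. Below is a minimal, self-contained sketch of the same discovery pattern; the package name "myapp.processors" and the REGISTRY name are hypothetical, not part of sglang.

import importlib
import logging
import pkgutil

logger = logging.getLogger(__name__)

# Process-wide registry, filled in by discover() below.
REGISTRY: dict = {}

def discover(package_name: str) -> None:
    """Import every submodule of `package_name` and merge its
    `ImageProcessorMapping` dict (if any) into REGISTRY."""
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if ispkg:
            continue
        try:
            module = importlib.import_module(name)
        except Exception as e:
            # A broken plugin should not break discovery of the others.
            logger.warning("Skipping %s: %s", name, e)
            continue
        mapping = getattr(module, "ImageProcessorMapping", None)
        if isinstance(mapping, dict):
            REGISTRY.update(mapping)

discover("myapp.processors")  # hypothetical plugin package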
sglang/srt/managers/image_processors/base_image_processor.py (deleted)
@@ -1,219 +0,0 @@
-import concurrent
-import concurrent.futures
-import dataclasses
-import multiprocessing as mp
-import os
-from abc import ABC, abstractmethod
-from typing import Optional
-
-import PIL
-import transformers
-from decord import VideoReader, cpu
-from PIL import Image
-
-from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import load_image
-from sglang.utils import logger
-
-global global_processor
-
-
-def get_global_processor():
-    global global_processor
-    return global_processor
-
-
-def init_global_processor(sglang_image_processor, server_args: ServerArgs):
-    """Init the global processor for multi-modal models."""
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = sglang_image_processor._build_processor(server_args=server_args)
-
-
-@dataclasses.dataclass
-class BaseImageProcessorOutput:
-    image_hashes: list[int]
-    image_sizes: list[tuple[int, int]]
-    all_frames: [PIL.Image]
-    # input_text, with each frame of video/image represented as an image_token
-    input_text: str
-
-
-class BaseImageProcessor(ABC):
-    def __init__(self, hf_config, server_args, _processor):
-        self.hf_config = hf_config
-        self._processor = _processor
-        self.server_args = server_args
-        # FIXME: not accurate, model and image specific
-        self.NUM_TOKEN_PER_FRAME = 330
-
-        self.executor = concurrent.futures.ProcessPoolExecutor(
-            initializer=init_global_processor,
-            mp_context=mp.get_context("fork"),
-            initargs=(
-                self,
-                server_args,
-            ),
-            max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
-        )
-
-    def _build_processor(self, server_args):
-        """Init the global processor for multi modal models."""
-        from sglang.srt.hf_transformers_utils import get_processor
-
-        return get_processor(
-            server_args.tokenizer_path,
-            tokenizer_mode=server_args.tokenizer_mode,
-            trust_remote_code=server_args.trust_remote_code,
-        )
-
-    @abstractmethod
-    async def process_images_async(
-        self, image_data, input_text, max_req_input_len, **kwargs
-    ):
-        pass
-
-    def get_estimated_frames_list(self, image_data):
-        """
-        estimate the total frame count from all visual input
-        """
-        # Before processing inputs
-        estimated_frames_list = []
-        for image in image_data:
-            if isinstance(image, str) and image.startswith("video:"):
-                path = image[len("video:") :]
-                # Estimate frames for the video
-                vr = VideoReader(path, ctx=cpu(0))
-                num_frames = len(vr)
-            else:
-                # For images, each contributes one frame
-                num_frames = 1
-            estimated_frames_list.append(num_frames)
-
-        return estimated_frames_list
-
-    @staticmethod
-    def encode_video(video_path, frame_count_limit=None):
-        if not os.path.exists(video_path):
-            logger.error(f"Video {video_path} does not exist")
-            return []
-
-        if frame_count_limit == 0:
-            return []
-
-        def uniform_sample(l, n):
-            gap = len(l) / n
-            idxs = [int(i * gap + gap / 2) for i in range(n)]
-            return [l[i] for i in idxs]
-
-        vr = VideoReader(video_path, ctx=cpu(0))
-        sample_fps = round(vr.get_avg_fps() / 1)  # FPS
-        frame_indices = [i for i in range(0, len(vr), sample_fps)]
-        if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
-            frame_indices = uniform_sample(frame_indices, frame_count_limit)
-
-        frames = vr.get_batch(frame_indices).asnumpy()
-        frames = [Image.fromarray(v.astype("uint8")) for v in frames]
-        return frames
-
-    def load_images(
-        self,
-        input_ids: list,
-        image_data,
-        image_token: str,
-        max_req_input_len: int,
-        return_text: Optional[bool] = True,
-        discard_alpha_channel: bool = True,
-    ) -> BaseImageProcessorOutput:
-        """
-        Each frame of video/image will be replaced by a single image token
-
-        Args:
-
-            discard_alpha_channel: if True, discards the alpha channel in the returned images
-
-        """
-        image_hashes, image_sizes = [], []
-        all_frames = []
-        new_text_parts = []
-
-        if isinstance(input_ids, list) and return_text:
-            assert len(input_ids) and isinstance(input_ids[0], int)
-            input_text = self._processor.tokenizer.decode(input_ids)
-        else:
-            input_text = input_ids
-
-        if return_text:
-            text_parts = input_text.split(image_token)
-
-        # TODO(mick): load from server_args, env, or sampling_params
-        MAX_NUM_FRAMES = 30
-        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
-        total_frame_count = sum(estimated_frames_list)
-        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
-        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
-
-        assert len(image_data) == len(estimated_frames_list)
-
-        # Process each input with allocated frames
-        for image_index, (image, estimated_frames) in enumerate(
-            zip(image_data, estimated_frames_list)
-        ):
-            if len(all_frames) >= MAX_NUM_FRAMES:
-                max_frames_to_process = 0
-            else:
-                max_frames_to_process = max(1, int(estimated_frames * scaling_factor))
-
-            if max_frames_to_process == 0:
-                frames = []
-            else:
-                try:
-                    if isinstance(image, str) and image.startswith("video:"):
-                        path = image[len("video:") :]
-                        frames = BaseImageProcessor.encode_video(
-                            path, frame_count_limit=max_frames_to_process
-                        )
-                    else:
-                        raw_image, _size = load_image(image)
-                        if discard_alpha_channel:
-                            raw_image = raw_image.convert("RGB")
-                        frames = [raw_image]
-                    assert len(frames) != 0
-                except FileNotFoundError as e:
-                    print(e)
-                    return None
-
-            image_sizes += [frames[0].size] * len(frames)
-            image_hashes += [hash(image)] * len(frames)
-            all_frames += frames
-
-            if return_text:
-                new_text_parts.append(text_parts[image_index])
-                if max_frames_to_process != 0:
-                    new_text_parts.append(image_token * len(frames))
-            assert max_frames_to_process >= len(frames)
-        if return_text:
-            new_text_parts.append(text_parts[-1])
-
-        input_text = "".join(new_text_parts)
-        return BaseImageProcessorOutput(
-            image_hashes, image_sizes, all_frames, input_text
-        )
-
-
-class DummyImageProcessor(BaseImageProcessor):
-    def __init__(self):
-        pass
-
-    async def process_images_async(self, *args, **kwargs):
-        return None
-
-
-def init_global_processor(
-    sglang_image_processor: BaseImageProcessor, server_args: ServerArgs
-):
-    """Init the global processor for multi-modal models."""
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = sglang_image_processor._build_processor(server_args=server_args)
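The frame-selection logic in encode_video above proceeds in two steps: sample candidates at roughly one frame per second (stepping by the rounded average FPS), then, if the candidate list exceeds the frame budget, thin it with uniform_sample, which takes the element at the center of each of n equal-width buckets. A small worked example with made-up numbers:

def uniform_sample(items, n):
    # One pick from the center of each of n equal-width buckets.
    gap = len(items) / n
    return [items[int(i * gap + gap / 2)] for i in range(n)]

# 10 candidate frame indices (e.g., a 300-frame video sampled at ~1 fps),
# thinned down to a budget of 4 frames:
frame_indices = list(range(0, 300, 30))
print(uniform_sample(frame_indices, 4))  # -> [30, 90, 180, 240]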
sglang/srt/managers/image_processors/minicpmv.py (deleted)
@@ -1,86 +0,0 @@
-import asyncio
-from typing import List, Union
-
-from sglang.srt.managers.image_processor import BaseImageProcessor
-from sglang.srt.managers.image_processors.base_image_processor import (
-    get_global_processor,
-)
-from sglang.srt.models.minicpmv import MiniCPMV
-
-
-class MiniCPMVImageProcessor(BaseImageProcessor):
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
-        self.IMAGE_TOKEN = "(<image>./</image>)"
-
-    @staticmethod
-    def _process_images_task(images, input_text):
-        processor = get_global_processor()
-        result = processor.__call__(text=input_text, images=images, return_tensors="pt")
-        return {
-            "input_ids": result.input_ids,
-            "pixel_values": result.pixel_values,
-            "tgt_sizes": result.tgt_sizes,
-        }
-
-    async def _process_images(self, images, input_text):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                MiniCPMVImageProcessor._process_images_task,
-                images,
-                input_text,
-            )
-        else:
-            image_inputs = self._processor(
-                images=images, text=input_text, return_tensors="pt"
-            )
-
-        return image_inputs
-
-    async def process_images_async(
-        self,
-        image_data: List[Union[str, bytes]],
-        input_ids,
-        request_obj,
-        max_req_input_len,
-    ):
-        if not image_data:
-            return None
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        base_output = self.load_images(
-            input_ids, image_data, self.IMAGE_TOKEN, max_req_input_len
-        )
-        if base_output is None:
-            return None
-
-        if len(base_output.all_frames) == 0:
-            return None
-        res = await self._process_images(
-            images=base_output.all_frames, input_text=base_output.input_text
-        )
-
-        # Collect special token ids
-        tokenizer = self._processor.tokenizer
-        im_start_id = tokenizer.im_start_id
-        im_end_id = tokenizer.im_end_id
-        if tokenizer.slice_start_id:
-            slice_start_id = tokenizer.slice_start_id
-            slice_end_id = tokenizer.slice_end_id
-        return {
-            "input_ids": res["input_ids"].flatten().tolist(),
-            "pixel_values": res["pixel_values"],
-            "tgt_sizes": res["tgt_sizes"],
-            "image_hashes": base_output.image_hashes,
-            "modalities": request_obj.modalities or ["image"],
-            "im_start_id": im_start_id,
-            "im_end_id": im_end_id,
-            "slice_start_id": slice_start_id,
-            "slice_end_id": slice_end_id,
-        }
-
-
-ImageProcessorMapping = {MiniCPMV: MiniCPMVImageProcessor}
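The _process_images helper above shows the standard recipe for keeping an asyncio server responsive while running a CPU-bound HuggingFace processor: ship the call to a ProcessPoolExecutor via run_in_executor. A minimal sketch of that recipe, with heavy_preprocess standing in for the real processor call (all names here are hypothetical):

import asyncio
from concurrent.futures import ProcessPoolExecutor

def heavy_preprocess(text: str) -> str:
    # Stand-in for a CPU-bound call such as processor(text=..., images=...).
    return text.upper()

executor = ProcessPoolExecutor(max_workers=2)

async def preprocess_async(text: str) -> str:
    loop = asyncio.get_running_loop()
    # The blocking call runs in a worker process; the event loop stays free.
    return await loop.run_in_executor(executor, heavy_preprocess, text)

if __name__ == "__main__":
    print(asyncio.run(preprocess_async("hello")))  # -> HELLO

Note that the worker function must be picklable (defined at module level), which is why the deleted code routes through a module-level _process_images_task rather than a bound method.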
sglang/srt/managers/multi_modality_padding.py (deleted)
@@ -1,134 +0,0 @@
-from abc import abstractmethod
-from typing import Callable, List, Optional, Tuple
-
-from sglang.srt.managers.schedule_batch import ImageInputs
-from sglang.utils import logger
-
-
-class MultiModalityDataPaddingPattern:
-    """
-    Data tokens (like image tokens) often need special handling during padding
-    to maintain model compatibility. This class provides the interface for
-    implementing different padding strategies for data tokens
-    """
-
-    @abstractmethod
-    def pad_input_tokens(
-        self, input_ids: List[int], image_inputs: ImageInputs
-    ) -> List[int]:
-        """
-        Pad the input ids sequence containing data tokens, and replace them with pad_values
-        """
-        pass
-
-
-class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
-    """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)
-
-    This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
-    """
-
-    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
-        self.data_token_id_pairs = data_token_pairs
-
-    def pad_input_tokens(
-        self, input_ids: List[int], image_inputs: ImageInputs
-    ) -> List[int]:
-        """
-        This function will replace the data-tokens inbetween with pad_values accordingly
-        """
-        pad_values = image_inputs.pad_values
-        data_token_pairs = self.data_token_id_pairs
-        image_inputs.image_offsets = []
-        if data_token_pairs is None:
-            data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id]
-        if data_token_pairs is None:
-            logger.warning(
-                "No data_token_pairs provided, RadixAttention might be influenced."
-            )
-            return input_ids
-        start_token_ids = [s for s, _e in data_token_pairs]
-        end_tokens_ids = [e for _s, e in data_token_pairs]
-        # First start token marks new data
-        data_start_token = start_token_ids[0]
-
-        padded_ids = []
-        last_idx = 0
-        data_idx = -1
-
-        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
-        end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]
-
-        if len(start_indices) != len(end_indices):
-            return input_ids
-
-        for start_idx, end_idx in zip(start_indices, end_indices):
-            padded_ids.extend(input_ids[last_idx : start_idx + 1])
-
-            if input_ids[start_idx] == data_start_token:
-                data_idx += 1
-                image_inputs.image_offsets += [start_idx]
-
-            num_tokens = end_idx - start_idx - 1
-            pad_value = pad_values[data_idx]
-            padded_ids.extend([pad_value] * num_tokens)
-
-            last_idx = end_idx
-
-        padded_ids.extend(input_ids[last_idx:])
-
-        assert len(input_ids) == len(padded_ids)
-        return padded_ids
-
-
-class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
-    """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
-    which needs first to be expanded to multiple tokens, then replaced with their padding values
-
-    This strategy should be used when a single data token represents content that should
-    be expanded to multiple tokens during processing.
-    """
-
-    def __init__(
-        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
-    ) -> None:
-        self.num_data_token_calc_func = num_data_token_calc_func
-
-    def pad_input_tokens(
-        self, input_ids: List[int], image_inputs: ImageInputs
-    ) -> List[int]:
-        """
-        This function will follow the procedure of:
-        1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
-        2. the padded data tokens will be replaced with their pad_values
-        """
-        image_grid_thws = image_inputs.image_grid_thws
-        pad_values = image_inputs.pad_values
-
-        image_indices = [
-            idx
-            for idx, token in enumerate(input_ids)
-            if token == image_inputs.im_token_id
-        ]
-
-        image_inputs.image_offsets = []
-
-        input_ids_with_image = []
-        for image_cnt, _ in enumerate(image_grid_thws):
-            print(f"image_cnt {image_cnt}")
-            num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
-            if image_cnt == 0:
-                non_image_tokens = input_ids[: image_indices[image_cnt]]
-            else:
-                non_image_tokens = input_ids[
-                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
-                ]
-            input_ids_with_image.extend(non_image_tokens)
-            image_inputs.image_offsets.append(len(input_ids_with_image))
-            pad_ids = pad_values * (
-                (num_image_tokens + len(pad_values)) // len(pad_values)
-            )
-            input_ids_with_image.extend(pad_ids[:num_image_tokens])
-        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
-
-        return input_ids_with_image
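To make the token-pair padding strategy concrete, here is a stripped-down re-implementation of pad_input_tokens from MultiModalityDataPaddingPatternTokenPairs, plus a worked example; the token ids 90/91 and the pad values are invented for illustration. Everything enclosed by a start/end pair is replaced by that image's pad value, the sequence length is preserved, and unbalanced pairs leave the input untouched:

def pad_token_pairs(input_ids, start_id, end_id, pad_values):
    out, last, img = [], 0, -1
    starts = [i for i, t in enumerate(input_ids) if t == start_id]
    ends = [i for i, t in enumerate(input_ids) if t == end_id]
    if len(starts) != len(ends):
        return input_ids  # unbalanced pairs: leave the sequence untouched
    for s, e in zip(starts, ends):
        out.extend(input_ids[last : s + 1])  # keep text plus the start token
        img += 1
        out.extend([pad_values[img]] * (e - s - 1))  # replace enclosed tokens
        last = e
    out.extend(input_ids[last:])
    return out

ids = [1, 90, 5, 5, 5, 91, 2, 90, 6, 6, 91, 3]
print(pad_token_pairs(ids, 90, 91, [7, 8]))
# -> [1, 90, 7, 7, 7, 91, 2, 90, 8, 8, 91, 3]

Giving each image a distinct pad value is what lets prefix caching (RadixAttention) distinguish otherwise-identical image placeholder spans, which is why the deleted code warns when no token pairs are available.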